#!/usr/bin/env python
##> Quick plot interactive
##> Plots data from data files with minimal effort.
import argparse
import csv
import os
import re
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# --------------------------
# CSV loading & explain
# --------------------------


import io

# Raw text read from stdin, filled the first time the "-" (stdin) file is
# loaded and reused on later loads, because stdin can only be consumed once.
# Stays None until then.
stdin_cache = None  # global cache for stdin content


def _read_source(file):
    """Return the full text of *file*, or of stdin when *file* is ``"-"``.

    Stdin can only be consumed once, so its content is cached in the
    module-level ``stdin_cache`` and reused whenever ``"-"`` is loaded again.
    Exits the program if a regular file does not exist.
    """
    global stdin_cache

    if file == "-":
        if stdin_cache is None:
            stdin_cache = sys.stdin.read()
        return stdin_cache
    if not os.path.exists(file):
        sys.exit(f"Error: file '{file}' not found")
    with open(file, "r", encoding="utf-8", errors="ignore") as f:
        return f.read()


def _is_label_token(token):
    """Return True if *token* looks like a column name rather than a data value.

    A token counts as a label when it contains a letter AND does not parse as
    a float. Numeric literals that contain letters (``1e-3``, ``nan``, ``inf``)
    are therefore still treated as data. Surrounding quotes are ignored.
    """
    token = token.strip("\"'")
    if not re.search(r"[A-Za-z]", token):
        return False
    try:
        float(token)
    except ValueError:
        return True
    return False


def load_file(file):
    """Load a delimited text file (or stdin when *file* is ``"-"``) into a DataFrame.

    The delimiter is sniffed from roughly the first 4 KiB, choosing among
    ``,``, ``;``, tab and space, and falls back to ``,``. A space delimiter is
    read as "any run of whitespace".

    The first non-blank line is treated as a header when any of its fields
    looks like a label (a letter that is not part of a float literal).
    Otherwise the columns are named ``col0``, ``col1``, ...

    The detected settings are stored in ``df.attrs["has_header"]`` and
    ``df.attrs["sep"]``.

    Exits the program (``sys.exit``) if the file does not exist or contains
    no non-blank lines.
    """
    content = _read_source(file)

    # pandas skips blank lines by default (skip_blank_lines=True), so header
    # detection must look at the first non-blank line as well.
    first_line = next(
        (ln.strip() for ln in content.splitlines() if ln.strip()), None
    )
    if first_line is None:
        sys.exit(f"Error: file '{file}' is empty")

    # Detect delimiter. If the sample is truncated, drop its trailing partial
    # line so the incomplete row cannot skew the Sniffer's per-line counts.
    sample = content[:4096]
    if len(content) > len(sample):
        cut = sample.rfind("\n")
        if cut > 0:
            sample = sample[:cut]
    try:
        dialect = csv.Sniffer().sniff(sample, delimiters=[",", ";", "\t", " "])
        sep = dialect.delimiter
    except csv.Error:
        sep = ","

    # Detect header
    parts = re.split(rf"[{re.escape(sep)}\s]+", first_line)
    has_header = any(_is_label_token(p) for p in parts)

    # Load CSV with pandas; a single-space delimiter means "any whitespace run"
    header = 0 if has_header else None
    read_sep = r"\s+" if sep == " " else sep
    df = pd.read_csv(io.StringIO(content), sep=read_sep, header=header)

    if not has_header:
        df.columns = [f"col{i}" for i in range(df.shape[1])]

    df.attrs["has_header"] = has_header
    df.attrs["sep"] = sep
    return df


def explain(file):
    """Print a human-readable summary of *file*.

    The summary covers row and column counts, the detected separator, whether
    a header was detected, and each column's dtype with its first value
    (or "(empty)" when the column has no rows).
    """
    frame = load_file(file)
    n_rows = len(frame)
    n_cols = len(frame.columns)
    separator = frame.attrs.get("sep", ",")
    header_found = frame.attrs.get("has_header", False)

    print(f"\n📄 File: {file}")
    print(f"→ Rows: {n_rows}, Columns: {n_cols}")
    print(f"→ Separator: '{separator}'")
    print(f"→ Header detected: {header_found}\n")
    print("Columns:")
    for name in frame.columns:
        column = frame[name]
        sample = "(empty)" if len(column) == 0 else column.iloc[0]
        dtype_text = str(column.dtype)
        print(f"  {name:<15} {dtype_text:<10} Example: {sample}")
    print()


# --------------------------
# Parsing helpers
# --------------------------


def parse_col(df, col):
    """Resolve a user-supplied column spec to a column label of *df*.

    Resolution order:

    1. An exact match against ``df.columns`` wins. A header that happens to be
       numeric (e.g. ``"2020"``) is therefore addressed by name, not position.
    2. An ASCII integer string, optionally negative (``"0"``, ``"-1"``), is a
       positional index into ``df.columns``.
    3. Anything else is returned unchanged. A missing column surfaces as a
       ``KeyError`` when the caller indexes the DataFrame.

    Raises:
        IndexError: if *col* is an integer string outside the column range.
    """
    if col in df.columns:
        return col
    # [0-9] rather than str.isdigit(): isdigit() accepts characters like "²"
    # that int() then rejects with ValueError.
    if re.fullmatch(r"-?[0-9]+", col):
        return df.columns[int(col)]
    return col


def parse_file_col(s):
    """Split a ``COLUMN[:FILE]`` spec into a ``(column, file)`` tuple.

    Only the first ``:`` separates the two parts, so the file part may itself
    contain colons. Without any ``:`` the file part is ``None``.
    """
    column, colon, filename = s.partition(":")
    if not colon:
        return s, None
    return column, filename


# --------------------------
# Main plotting logic
# --------------------------


def _split_specs(text):
    """Parse a comma-separated list of ``COLUMN[:FILE]`` specs, skipping empty items."""
    if not text:
        return []
    return [parse_file_col(part) for part in text.split(",") if part]


def _get_frame(dfs, file, default_file):
    """Return ``(key, DataFrame)`` for *file* (or *default_file* if falsy), loading on first use."""
    key = file or default_file
    if key not in dfs:
        dfs[key] = load_file(key)
    return key, dfs[key]


def _resolve_column(df, key, spec):
    """Map a user column spec to a label of *df*, exiting with a clear message if absent."""
    try:
        name = parse_col(df, spec)
    except IndexError:
        name = None
    if name not in df.columns:
        available = ", ".join(str(c) for c in df.columns)
        sys.exit(f"Error: column '{spec}' not found in '{key}' (columns: {available})")
    return name


def _collect_series(args, dfs, default_file):
    """Resolve the -x/-y specs into a list of ``(x, y, label, xlabel)`` tuples."""
    x_specs = _split_specs(args.x) or [(None, None)]
    y_specs = _split_specs(args.y)

    # (file key, resolved column label) pairs to plot on the Y axis.
    targets = []
    if y_specs:
        for ycol, yfile in y_specs:
            key, df = _get_frame(dfs, yfile, default_file)
            targets.append((key, _resolve_column(df, key, ycol)))
    else:
        # Default: every column of every positional file, except columns used
        # as X. Real labels are used directly; re-parsing them with parse_col
        # could misread a numeric header as a positional index.
        x_used = set()
        for xcol, xfile in x_specs:
            if xcol is not None:
                key, df = _get_frame(dfs, xfile, default_file)
                x_used.add((key, _resolve_column(df, key, xcol)))
        for key in dict.fromkeys(args.files):  # de-duplicate, keep order
            targets.extend(
                (key, c) for c in dfs[key].columns if (key, c) not in x_used
            )

    # Tag labels with their file only when series come from several files.
    multi_file = len({key for key, _ in targets}) > 1
    series = []
    for i, (key, ycol) in enumerate(targets):
        df_y = dfs[key]
        # The i-th Y series uses the i-th X spec; the last X spec is reused
        # when there are more Y series than X specs.
        xcol, xfile = x_specs[min(i, len(x_specs) - 1)]
        if xcol is None:
            x = np.arange(len(df_y))
            xlabel = "Sample index"
        else:
            xkey, df_x = _get_frame(dfs, xfile, default_file)
            xlabel = _resolve_column(df_x, xkey, xcol)
            x = df_x[xlabel]
        y = df_y[ycol]
        if len(x) != len(y):
            sys.exit(
                f"Error: X '{xlabel}' has {len(x)} rows but Y '{ycol}' "
                f"from '{key}' has {len(y)} rows"
            )
        label = f"{ycol} [{key}]" if multi_file else ycol
        series.append((x, y, label, xlabel))
    return series


def _common_xlabel(series):
    """Return the X label shared by all series, or "" if they differ."""
    labels = {str(s[3]) for s in series}
    return labels.pop() if len(labels) == 1 else ""


def _apply_scales(ax, args):
    """Apply --logx/--logy and the standard 5% margins to *ax*."""
    if args.logx:
        ax.set_xscale("log")
    if args.logy:
        ax.set_yscale("log")
    ax.margins(x=0.05, y=0.05)


def _plot_spread(series, args):
    """Draw one subplot per series, stacked vertically with a shared X axis."""
    n = len(series)
    fig, grid = plt.subplots(
        n, 1, sharex=True, figsize=(10, 3 * n), squeeze=False
    )
    panels = grid[:, 0]
    for ax, (x, y, label, _) in zip(panels, series):
        ax.plot(x, y, label=label)
        ax.set_ylabel(label)
        _apply_scales(ax, args)
        ax.legend()
    panels[-1].set_xlabel(_common_xlabel(series))
    return fig


def _plot_overlay(series, args):
    """Draw all series on one plot; with --non-shared each extra series gets its own Y axis."""
    fig, base = plt.subplots(figsize=(10, 5))
    axes = [base]
    colors = plt.rcParams["axes.prop_cycle"].by_key().get("color", [])

    for k, (x, y, label, _) in enumerate(series):
        ax = base
        if args.non_shared and k > 0:
            ax = base.twinx()
            if k > 1:
                # Push each additional right spine outward so the tick labels
                # of successive twin axes do not overlap.
                ax.spines["right"].set_position(("axes", 1 + 0.12 * (k - 1)))
            axes.append(ax)
        # Every twinx axis restarts the colour cycle, so assign colours
        # explicitly to keep the series distinguishable.
        style = {"color": colors[k % len(colors)]} if colors else {}
        ax.plot(x, y, label=label, **style)
        if args.non_shared:
            ax.set_ylabel(label, **style)

    for ax in axes:
        _apply_scales(ax, args)
    xlabel = _common_xlabel(series)
    if xlabel:
        base.set_xlabel(xlabel)

    # A legend only sees its own axis' lines, so gather handles from every
    # axis. Put it on the top-most axis so twin-axis lines do not cover it.
    handles, labels = [], []
    for ax in axes:
        h, lab = ax.get_legend_handles_labels()
        handles.extend(h)
        labels.extend(lab)
    axes[-1].legend(handles, labels)
    return fig


def main():
    """Command-line entry point: parse arguments, load the files and plot them.

    ``-x``/``-y`` take comma-separated ``COLUMN[:FILE]`` specs. COLUMN is a
    column name or a 0-based index; FILE defaults to the first positional file.
    Files named only inside a spec are loaded on demand.

    Without ``-y``, every column of every positional file is plotted, except
    columns used as X. With several ``-x`` specs, the i-th Y series uses the
    i-th X spec. Without ``-x``, the sample index is used.

    ``--spread`` draws one subplot per series instead of a single overlay.
    """
    p = argparse.ArgumentParser(description="Quick Plot for CSV files")
    p.add_argument("files", nargs="+", help="CSV files")
    p.add_argument("-x", "--x", help="X column(s), optionally file-specified")
    p.add_argument("-y", "--y", help="Y column(s), optionally file-specified")
    p.add_argument("--explain", "-e", action="store_true", help="Explain file contents")
    p.add_argument("--spread", action="store_true", help="One subplot per column")
    p.add_argument("-o", "--output", help="Save figure to file instead of showing")
    p.add_argument("--dpi", type=int, default=150, help="DPI for saved figure")
    p.add_argument("--logx", action="store_true", help="Logarithmic X axis")
    p.add_argument("--logy", action="store_true", help="Logarithmic Y axis")
    p.add_argument(
        "--non-shared", action="store_true", help="Separate Y axes for each series"
    )
    args = p.parse_args()

    if args.explain:
        for f in args.files:
            explain(f)
        return

    # Load all positional files up front; spec-only files load lazily.
    default_file = args.files[0]
    dfs = {f: load_file(f) for f in args.files}

    series = _collect_series(args, dfs, default_file)
    if not series:
        sys.exit("Error: nothing to plot")

    fig = _plot_spread(series, args) if args.spread else _plot_overlay(series, args)
    fig.tight_layout()

    if args.output:
        fig.savefig(args.output, dpi=args.dpi, bbox_inches="tight")
        print(f"Plot saved to {args.output}")
    else:
        plt.show()


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
