#!/usr/bin/env python3
"""Analyze a Cursor usage CSV export and generate token and request charts."""

from __future__ import annotations

import sys
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import matplotlib.ticker as ticker
import pandas as pd

# Default locations are resolved relative to this script's directory, not the CWD.
_HERE = Path(__file__).parent
CSV_PATH = _HERE / "usage.csv"  # input used when no CLI argument is given
OUT_DIR = _HERE / "charts"  # every chart PNG is saved under this directory

# Token-count columns; load_data coerces each of these to int.
TOKEN_COLS = [
    "Input (w/ Cache Write)",
    "Input (w/o Cache Write)",
    "Cache Read",
    "Output Tokens",
    "Total Tokens",
]

# Shared colour palette; the chart functions pick entries by index.
PALETTE = [
    "#6366f1",  # indigo
    "#f59e0b",  # amber
    "#10b981",  # emerald
    "#ef4444",  # red
    "#3b82f6",  # blue
    "#8b5cf6",  # violet
    "#ec4899",  # pink
    "#14b8a6",  # teal
    "#f97316",  # orange
    "#64748b",  # slate
]


def fmt_millions(x: float, _pos: int | None = None) -> str:
    if x >= 1e9:
        return f"{x / 1e9:.1f}B"
    if x >= 1e6:
        return f"{x / 1e6:.1f}M"
    if x >= 1e3:
        return f"{x / 1e3:.0f}K"
    return f"{x:.0f}"


def short_user(email: str) -> str:
    """Turn an email address into a short display name.

    ``"jane.doe@example.com"`` -> ``"Jane Doe"``: keep the part before ``@``,
    replace dots with spaces, and title-case the result.  Non-string values
    (e.g. NaN from a missing CSV cell) and the literal ``"N/A"`` map to
    ``"N/A"``.
    """
    # Check the type first: comparing a missing-value sentinel such as pd.NA
    # with a string returns pd.NA, whose truthiness raises TypeError.
    if not isinstance(email, str) or email == "N/A":
        return "N/A"
    return email.split("@")[0].replace(".", " ").title()


def load_data(path: Path) -> pd.DataFrame:
    """Read the usage CSV at *path* and add the derived columns the charts use.

    - Each column in ``TOKEN_COLS`` is coerced to numbers; non-numeric or
      missing values become 0 and the column is cast to ``int``.
    - ``short_user``: display name derived from ``User`` via ``short_user``.
    - ``date``: calendar date of the parsed ``Date`` timestamp.
    - ``hour``: ``Date`` floored to the start of its hour.
    """
    frame = pd.read_csv(path, parse_dates=["Date"])
    for name in TOKEN_COLS:
        as_number = pd.to_numeric(frame[name], errors="coerce")
        frame[name] = as_number.fillna(0).astype(int)

    frame["short_user"] = frame["User"].apply(short_user)

    stamps = frame["Date"].dt
    frame["date"] = stamps.date
    frame["hour"] = stamps.floor("h")
    return frame


def style_ax(ax: plt.Axes, title: str, ylabel: str = "Tokens") -> None:
    """Apply the look shared by the bar/line charts in this script.

    Sets a bold title and the y-axis label, formats y tick values with
    ``fmt_millions``, hides the top and right frame lines, and draws faint
    horizontal gridlines.

    NOTE(review): installing the y formatter replaces whatever formatter the
    axis already had — including category labels from a categorical ``barh``
    — so such callers must set their y tick labels after calling this.
    """
    ax.set_title(title, fontsize=14, fontweight="bold", pad=12)
    ax.set_ylabel(ylabel, fontsize=10)
    ax.yaxis.set_major_formatter(ticker.FuncFormatter(fmt_millions))
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)
    ax.grid(axis="y", alpha=0.3)


# ── Chart 1: Total tokens per user ─────────────────────

def chart_tokens_per_user(df: pd.DataFrame) -> None:
    """Save a horizontal bar chart of summed ``Total Tokens`` per user.

    Users are keyed by the ``short_user`` display name and sorted ascending, so
    the heaviest user is drawn at the top.  Each bar carries its formatted
    total.  Output: ``OUT_DIR/01_tokens_per_user.png``.
    """
    grouped = (
        df.groupby("short_user")["Total Tokens"]
        .sum()
        .sort_values(ascending=True)
    )
    # Numeric bar positions; the names are attached as explicit tick labels below.
    positions = list(range(len(grouped)))
    fig, ax = plt.subplots(figsize=(10, max(6, len(grouped) * 0.4)))
    bars = ax.barh(positions, grouped.values, color=PALETTE[0], edgecolor="white", height=0.7)
    # Value label just past each bar's end, offset by 1% of the longest bar.
    for bar, val in zip(bars, grouped.values):
        ax.text(
            bar.get_width() + grouped.max() * 0.01,
            bar.get_y() + bar.get_height() / 2,
            fmt_millions(val),
            va="center",
            fontsize=8,
        )
    style_ax(ax, "Total Tokens per User", "")
    # style_ax installs a token formatter on the y-axis, which here holds the
    # categories; set explicit ticks/labels afterwards so the names survive
    # instead of being rendered as "0", "1", "2", ...
    ax.set_yticks(positions)
    ax.set_yticklabels(grouped.index)
    # Horizontal bars: gridlines belong on the value (x) axis.
    ax.grid(False, axis="y")
    ax.grid(axis="x", alpha=0.3)
    ax.set_xlabel("Tokens", fontsize=10)
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(fmt_millions))
    fig.tight_layout()
    fig.savefig(OUT_DIR / "01_tokens_per_user.png", dpi=150)
    plt.close(fig)


# ── Chart 2: Total tokens per model ────────────────────

def chart_tokens_per_model(df: pd.DataFrame) -> None:
    """Save a horizontal bar chart of summed ``Total Tokens`` per model.

    Models are sorted ascending, so the heaviest model is drawn at the top.
    Each bar carries its formatted total.
    Output: ``OUT_DIR/02_tokens_per_model.png``.
    """
    grouped = (
        df.groupby("Model")["Total Tokens"]
        .sum()
        .sort_values(ascending=True)
    )
    # Numeric bar positions; the model names are attached as tick labels below.
    positions = list(range(len(grouped)))
    fig, ax = plt.subplots(figsize=(10, max(5, len(grouped) * 0.45)))
    bars = ax.barh(positions, grouped.values, color=PALETTE[1], edgecolor="white", height=0.7)
    # Value label just past each bar's end, offset by 1% of the longest bar.
    for bar, val in zip(bars, grouped.values):
        ax.text(
            bar.get_width() + grouped.max() * 0.01,
            bar.get_y() + bar.get_height() / 2,
            fmt_millions(val),
            va="center",
            fontsize=8,
        )
    style_ax(ax, "Total Tokens per Model", "")
    # style_ax installs a token formatter on the y-axis; set explicit
    # ticks/labels afterwards so the model names are shown rather than
    # "0", "1", "2", ...
    ax.set_yticks(positions)
    ax.set_yticklabels(grouped.index)
    # Horizontal bars: gridlines belong on the value (x) axis.
    ax.grid(False, axis="y")
    ax.grid(axis="x", alpha=0.3)
    ax.set_xlabel("Tokens", fontsize=10)
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(fmt_millions))
    fig.tight_layout()
    fig.savefig(OUT_DIR / "02_tokens_per_model.png", dpi=150)
    plt.close(fig)


# ── Chart 3: Daily usage over time (stacked by token type) ──

def chart_daily_token_breakdown(df: pd.DataFrame) -> None:
    """Save a stacked bar chart of daily token usage split by token type.

    One bar per calendar day present in ``df["date"]``.  The bar is built
    bottom-up from cache-write input, non-cache input, cache reads, and output
    tokens.  Output: ``OUT_DIR/03_daily_token_breakdown.png``.
    """
    breakdown_cols = ["Input (w/ Cache Write)", "Input (w/o Cache Write)", "Cache Read", "Output Tokens"]
    daily = df.groupby("date")[breakdown_cols].sum()
    daily.index = pd.to_datetime(daily.index)

    fig, ax = plt.subplots(figsize=(12, 6))
    colors = [PALETTE[0], PALETTE[4], PALETTE[2], PALETTE[3]]
    labels = ["Input (Cache Write)", "Input (No Cache)", "Cache Read", "Output"]
    bottom = None
    for col, color, label in zip(breakdown_cols, colors, labels):
        vals = daily[col].values
        ax.bar(daily.index, vals, bottom=bottom, color=color, label=label, width=0.7, edgecolor="white", linewidth=0.3)
        # Running top of the stack; "+" allocates a new array, so the frame is untouched.
        bottom = vals if bottom is None else bottom + vals

    style_ax(ax, "Daily Token Breakdown")
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%b %d"))
    # One tick per day is fine for short exports, but a long export would put
    # hundreds of overlapping labels on the axis; thin to at most ~20 ticks.
    max_ticks = 20
    span_days = (daily.index.max() - daily.index.min()).days + 1 if len(daily) else 1
    interval = max(1, (span_days + max_ticks - 1) // max_ticks)
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=interval))
    ax.legend(fontsize=9, ncol=4, loc="upper left")
    fig.autofmt_xdate(rotation=30)
    fig.tight_layout()
    fig.savefig(OUT_DIR / "03_daily_token_breakdown.png", dpi=150)
    plt.close(fig)


# ── Chart 4: Daily usage per user (top N) ──────────────

def chart_daily_per_user(df: pd.DataFrame, top_n: int = 8) -> None:
    """Save a line chart of daily ``Total Tokens`` for the heaviest users.

    ``top_n`` users are selected by summed ``Total Tokens``; each gets one line
    whose colour and legend order follow that ranking.  Every calendar day
    between the first and last active day is plotted, with 0 on days a user
    had no usage.  Output: ``OUT_DIR/04_daily_per_user.png``.
    """
    top_users = df.groupby("short_user")["Total Tokens"].sum().nlargest(top_n).index.tolist()
    subset = df[df["short_user"].isin(top_users)]
    pivot = subset.pivot_table(index="date", columns="short_user", values="Total Tokens", aggfunc="sum", fill_value=0)
    pivot.index = pd.to_datetime(pivot.index)
    # Use a contiguous daily index so a line drops to 0 on idle days instead
    # of being drawn straight across the gap.  Reindexing the columns keeps
    # ranking order and guarantees a column for every selected user.
    days = (
        pd.date_range(pivot.index.min(), pivot.index.max(), freq="D")
        if len(pivot)
        else pivot.index
    )
    pivot = pivot.reindex(index=days, columns=top_users, fill_value=0)

    fig, ax = plt.subplots(figsize=(12, 6))
    for i, user in enumerate(top_users):
        ax.plot(pivot.index, pivot[user], marker="o", markersize=4, color=PALETTE[i % len(PALETTE)], label=user, linewidth=2)

    # Report how many users were actually plotted (can be fewer than top_n).
    style_ax(ax, f"Daily Tokens — Top {len(top_users)} Users")
    ax.xaxis.set_major_formatter(mdates.DateFormatter("%b %d"))
    # One tick per day, thinned to at most ~20 ticks for long date ranges.
    max_ticks = 20
    interval = max(1, (len(pivot) + max_ticks - 1) // max_ticks)
    ax.xaxis.set_major_locator(mdates.DayLocator(interval=interval))
    ax.legend(fontsize=8, ncol=2, loc="upper left")
    fig.autofmt_xdate(rotation=30)
    fig.tight_layout()
    fig.savefig(OUT_DIR / "04_daily_per_user.png", dpi=150)
    plt.close(fig)


# ── Chart 5: Requests per user ──────────────────────────

def chart_requests_per_user(df: pd.DataFrame) -> None:
    """Save a horizontal bar chart of request count (CSV rows) per user.

    Users are keyed by ``short_user`` and sorted ascending, so the most active
    user is drawn at the top.  Each bar carries its count.
    Output: ``OUT_DIR/05_requests_per_user.png``.
    """
    counts = df["short_user"].value_counts().sort_values(ascending=True)
    # Numeric bar positions; the user names are attached as tick labels below.
    positions = list(range(len(counts)))
    fig, ax = plt.subplots(figsize=(10, max(6, len(counts) * 0.4)))
    bars = ax.barh(positions, counts.values, color=PALETTE[5], edgecolor="white", height=0.7)
    # Count label just past each bar's end, offset by 1% of the longest bar.
    for bar, val in zip(bars, counts.values):
        ax.text(
            bar.get_width() + counts.max() * 0.01,
            bar.get_y() + bar.get_height() / 2,
            str(val),
            va="center",
            fontsize=8,
        )
    style_ax(ax, "Total Requests per User", "")
    # style_ax installs a token formatter on the y-axis; set explicit
    # ticks/labels afterwards so the user names are shown rather than
    # "0", "1", "2", ...
    ax.set_yticks(positions)
    ax.set_yticklabels(counts.index)
    # Horizontal bars: gridlines belong on the value (x) axis.
    ax.grid(False, axis="y")
    ax.grid(axis="x", alpha=0.3)
    ax.set_xlabel("Requests", fontsize=10)
    fig.tight_layout()
    fig.savefig(OUT_DIR / "05_requests_per_user.png", dpi=150)
    plt.close(fig)


# ── Chart 6: Model popularity (pie) ────────────────────

def chart_model_share(df: pd.DataFrame, top_n: int = 6) -> None:
    """Save a pie chart of each model's share of requests (CSV rows).

    The ``top_n`` most-requested models get their own wedge; all remaining
    models are folded into a single "Other" wedge.
    Output: ``OUT_DIR/06_model_share.png``.
    """
    counts = df["Model"].value_counts()
    top = counts.head(top_n)
    if len(counts) > top_n:
        # Append a new entry instead of assigning top["Other"]: assignment would
        # silently overwrite the count of a model that is literally named "Other".
        other_label = "Other (remaining)" if "Other" in top.index else "Other"
        top = pd.concat([top, pd.Series({other_label: counts.iloc[top_n:].sum()})])

    fig, ax = plt.subplots(figsize=(8, 8))
    _, _, autotexts = ax.pie(
        top.values,
        labels=top.index,
        autopct="%1.1f%%",
        colors=PALETTE[: len(top)],
        startangle=140,
        pctdistance=0.8,
    )
    for t in autotexts:
        t.set_fontsize(9)
    ax.set_title("Model Usage Share (by request count)", fontsize=14, fontweight="bold", pad=16)
    fig.tight_layout()
    fig.savefig(OUT_DIR / "06_model_share.png", dpi=150)
    plt.close(fig)


# ── Chart 7: Hourly heatmap-style (requests per hour of day per weekday) ──

def chart_hourly_heatmap(df: pd.DataFrame) -> None:
    """Save a weekday × hour-of-day heatmap of request counts (CSV rows).

    Rows are always Monday–Sunday and columns always hours 0–23 (empty slots
    are 0), so every column is one consecutive hour regardless of which slots
    the data covers.  Weekday and hour come from ``Date`` as parsed, with no
    timezone conversion.  Output: ``OUT_DIR/07_hourly_heatmap.png``.
    """
    order = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"]
    stamps = df["Date"].dt
    # Count rows per (weekday, hour) without copying the whole frame.
    pivot = pd.crosstab(stamps.day_name().rename("weekday"), stamps.hour.rename("hour_of_day"))
    # Fill absent weekdays/hours with 0; otherwise an hour with no requests
    # would vanish from the axis and neighbouring columns would not be
    # neighbouring hours.
    pivot = pivot.reindex(index=order, columns=range(24), fill_value=0)

    fig, ax = plt.subplots(figsize=(14, 5))
    im = ax.imshow(pivot.values, cmap="YlOrRd", aspect="auto")
    ax.set_xticks(range(pivot.shape[1]))
    ax.set_xticklabels([f"{h}:00" for h in pivot.columns], fontsize=8, rotation=45)
    ax.set_yticks(range(pivot.shape[0]))
    ax.set_yticklabels(pivot.index, fontsize=10)
    ax.set_title("Request Activity Heatmap (weekday × hour)", fontsize=14, fontweight="bold", pad=12)
    fig.colorbar(im, ax=ax, label="Requests", shrink=0.8)
    fig.tight_layout()
    fig.savefig(OUT_DIR / "07_hourly_heatmap.png", dpi=150)
    plt.close(fig)


# ── Summary table (printed to stdout) ──────────────────

def print_summary(df: pd.DataFrame) -> None:
    """Print a plain-text usage report to stdout.

    The report contains:
    - the covered time span;
    - request and token totals;
    - distinct user and model counts;
    - average tokens per request;
    - the five users with the most tokens;
    - the resolved ``OUT_DIR`` path where the charts were written.
    """
    tokens_sum = df["Total Tokens"].sum()
    n_rows = len(df)
    start, end = df["Date"].min(), df["Date"].max()
    period = f"{start:%Y-%m-%d %H:%M} → {end:%Y-%m-%d %H:%M}"
    n_users = df.loc[df["User"] != "N/A", "User"].nunique()
    n_models = df["Model"].nunique()
    per_request = tokens_sum / max(n_rows, 1)  # guard against an empty frame

    rule = "=" * 60
    header = [
        "\n" + rule,
        "  CURSOR USAGE SUMMARY",
        rule,
        f"  Period:          {period}",
        f"  Total requests:  {n_rows:,}",
        f"  Total tokens:    {fmt_millions(tokens_sum)}",
        f"  Unique users:    {n_users}",
        f"  Unique models:   {n_models}",
        f"  Avg tokens/req:  {fmt_millions(per_request)}",
        rule,
    ]
    print("\n".join(header))

    print("\n  Top 5 users by total tokens:")
    leaders = df.groupby("short_user")["Total Tokens"].sum().nlargest(5)
    for rank, (name, count) in enumerate(leaders.items(), start=1):
        print(f"    {rank}. {name:<25s} {fmt_millions(count):>8s}")

    print(f"\n  Charts saved to: {OUT_DIR.resolve()}/")
    print()


# ── Main ────────────────────────────────────────────────

def main() -> None:
    """Command-line entry point.

    The optional first argument is the path to the usage CSV (default:
    ``CSV_PATH``).  Exits with status 1 if that path does not exist.
    Otherwise it writes every chart into ``OUT_DIR``, reporting each file as
    it is saved, and finishes by printing the summary.
    """
    csv_path = CSV_PATH if len(sys.argv) <= 1 else Path(sys.argv[1])
    if not csv_path.exists():
        print(f"ERROR: CSV not found at {csv_path}", file=sys.stderr)
        sys.exit(1)

    OUT_DIR.mkdir(parents=True, exist_ok=True)

    print(f"Loading {csv_path} ...")
    df = load_data(csv_path)
    print(f"Loaded {len(df):,} rows")

    # (chart function, file it writes) — rendered and reported in this order.
    steps = (
        (chart_tokens_per_user, "01_tokens_per_user.png"),
        (chart_tokens_per_model, "02_tokens_per_model.png"),
        (chart_daily_token_breakdown, "03_daily_token_breakdown.png"),
        (chart_daily_per_user, "04_daily_per_user.png"),
        (chart_requests_per_user, "05_requests_per_user.png"),
        (chart_model_share, "06_model_share.png"),
        (chart_hourly_heatmap, "07_hourly_heatmap.png"),
    )
    for render, filename in steps:
        render(df)
        print(f"  ✓ {filename}")

    print_summary(df)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
