#!/usr/bin/env python3
from __future__ import annotations

import json
import math
import os
import subprocess
import sys
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

SESSIONS_DIR = Path(os.environ.get("CODEX_SESSIONS_DIR", "/home/sebas/.codex/sessions"))
STATE_FILE = Path(os.environ.get("CODEX_USAGE_ALERT_STATE", "/home/sebas/runtime/codex-usage-alert/state.json"))
TELEGRAM_NOTIFY = Path(os.environ.get("CODEX_TELEGRAM_NOTIFY", "/home/sebas/.agents/skills/telegram-notify/telegram-notify"))
# Percent-used levels of the 5h (primary) window that trigger an alert.
# Blank entries (a trailing comma, or "80, ,95") are ignored instead of making
# int() raise on a whitespace-only string.
PRIMARY_WARN_LEVELS = tuple(
    int(x) for x in os.environ.get("CODEX_PRIMARY_WARN_LEVELS", "80,90,95").split(",") if x.strip()
)
# (threshold_hours, threshold_remaining): alert when the weekly window resets
# within threshold_hours hours while at least threshold_remaining percent of it
# is still free.
WEEKLY_THRESHOLDS = [
    (30, 35),
    (15, 20),
    (8, 10),
]
# Minimum free percent of the 5h window for a weekly alert to report "actionable_now".
ACTIONABLE_PRIMARY_FREE_PERCENT = float(os.environ.get("CODEX_ACTIONABLE_PRIMARY_FREE_PERCENT", "25"))
# Snapshots older than this (by their logged timestamp) are not alerted on.
MAX_SNAPSHOT_AGE_HOURS = float(os.environ.get("CODEX_MAX_SNAPSHOT_AGE_HOURS", "12"))


@dataclass
class Snapshot:
    """One rate-limit reading taken from a Codex session log ``token_count`` event.

    ``primary`` is the short window the alert messages call "5h"; ``secondary``
    is the weekly window. When built by ``load_last_snapshot_from_file``, numeric
    fields missing from the log are stored as 0.
    """

    ts: str  # the event's "timestamp" field as logged ("" if absent)
    primary_used: float  # percent of the primary (5h) window already used
    primary_window_min: int  # primary window length in minutes
    primary_resets_at: int  # primary reset time, compared against Unix-epoch seconds
    secondary_used: float  # percent of the secondary (weekly) window already used
    secondary_window_min: int  # secondary window length in minutes
    secondary_resets_at: int  # secondary reset time, compared against Unix-epoch seconds
    plan_type: str | None  # "plan_type" from the rate_limits record, if present


def load_latest_snapshot() -> Snapshot | None:
    """Return the rate-limit snapshot from the most recently modified session log.

    Every ``*.jsonl`` file under SESSIONS_DIR is ranked by modification time,
    newest first, and the last snapshot of the first file that contains one is
    returned. Files that disappear or cannot be stat'ed during the scan are
    skipped. Among files with identical mtimes, the one listed first by
    ``rglob`` wins.

    Returns:
        The snapshot, or None if SESSIONS_DIR does not exist or no session log
        holds a usable snapshot.
    """
    if not SESSIONS_DIR.exists():
        return None
    candidates: list[tuple[float, Path]] = []
    for path in SESSIONS_DIR.rglob("*.jsonl"):
        try:
            candidates.append((path.stat().st_mtime, path))
        except OSError:
            # The log may have been removed or rotated between listing and stat.
            continue
    # Newest first. The first file yielding a snapshot is the answer, so older
    # (often numerous and large) session logs are never parsed. sort() is
    # stable even with reverse=True, preserving rglob order on mtime ties.
    candidates.sort(key=lambda item: item[0], reverse=True)
    for _mtime, path in candidates:
        snap = load_last_snapshot_from_file(path)
        if snap is not None:
            return snap
    return None


def load_last_snapshot_from_file(path: Path) -> Snapshot | None:
    """Return the last rate-limit snapshot recorded in one session log.

    Each line of *path* is parsed as JSON. Only objects whose ``payload.type``
    is ``"token_count"`` and whose rate limits (``payload.info.rate_limits``,
    else ``payload.rate_limits``) carry non-empty ``primary`` and ``secondary``
    objects are used. Malformed lines are skipped instead of aborting the scan:
    invalid JSON, non-object values at any of those levels, undecodable bytes,
    or numbers that cannot be converted.

    Returns:
        The snapshot from the last qualifying line, or None if the file cannot
        be opened/read or contains no qualifying line.
    """
    last: Snapshot | None = None
    try:
        # errors="replace": a log still being written may end in a partial
        # multi-byte sequence; that must not raise UnicodeDecodeError (which is
        # not an OSError) and crash the whole run.
        with path.open(encoding="utf-8", errors="replace") as fh:
            for line in fh:
                try:
                    data = json.loads(line)
                except ValueError:
                    continue
                # Arrays, scalars and null have no .get(); skip them.
                if not isinstance(data, dict):
                    continue
                payload = data.get("payload")
                if not isinstance(payload, dict) or payload.get("type") != "token_count":
                    continue
                info = payload.get("info")
                if not isinstance(info, dict):
                    info = {}
                rate_limits = info.get("rate_limits") or payload.get("rate_limits")
                if not rate_limits or not isinstance(rate_limits, dict):
                    continue
                primary = rate_limits.get("primary")
                secondary = rate_limits.get("secondary")
                if not (primary and isinstance(primary, dict)):
                    continue
                if not (secondary and isinstance(secondary, dict)):
                    continue
                try:
                    last = Snapshot(
                        ts=str(data.get("timestamp") or ""),
                        primary_used=float(primary.get("used_percent") or 0.0),
                        primary_window_min=int(primary.get("window_minutes") or 0),
                        primary_resets_at=int(primary.get("resets_at") or 0),
                        secondary_used=float(secondary.get("used_percent") or 0.0),
                        secondary_window_min=int(secondary.get("window_minutes") or 0),
                        secondary_resets_at=int(secondary.get("resets_at") or 0),
                        plan_type=rate_limits.get("plan_type"),
                    )
                except (TypeError, ValueError, OverflowError):
                    # Non-numeric, NaN or infinite field values: ignore this line.
                    continue
    except OSError:
        return None
    return last


def load_state(path: Path | None = None) -> dict[str, Any]:
    """Load the persisted alert state.

    Args:
        path: State file to read; defaults to STATE_FILE.

    Returns:
        The decoded JSON object. An empty dict is returned if the file is
        missing, unreadable, not valid JSON, or its top-level value is not an
        object, so callers always get a mutable mapping and start fresh.
    """
    target = STATE_FILE if path is None else path
    try:
        data = json.loads(target.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers JSONDecodeError and UnicodeDecodeError.
        return {}
    # A list/scalar would crash the dict operations performed on the state later.
    return data if isinstance(data, dict) else {}


def save_state(state: dict[str, Any], path: Path | None = None) -> None:
    """Persist *state* as indented, key-sorted JSON followed by a newline.

    Args:
        state: JSON-serialisable alert state.
        path: Destination file; defaults to STATE_FILE. Parent directories are
            created as needed.

    The JSON is first written to a temporary file in the same directory and then
    moved into place with ``os.replace``. An interrupted run therefore never
    leaves a truncated state file, which would lose the record of alerts already
    sent.

    Raises:
        OSError: the directory or file cannot be written.
        TypeError: *state* is not JSON-serialisable (nothing is written).
    """
    target = STATE_FILE if path is None else path
    target.parent.mkdir(parents=True, exist_ok=True)
    text = json.dumps(state, indent=2, sort_keys=True) + "\n"
    # Same directory so os.replace stays on one filesystem; pid avoids clashes
    # between overlapping runs.
    tmp = target.with_name(f".{target.name}.{os.getpid()}.tmp")
    try:
        tmp.write_text(text, encoding="utf-8")
        os.replace(tmp, target)
    except BaseException:
        try:
            tmp.unlink()
        except OSError:
            pass
        raise


def fmt_hours(seconds: float) -> str:
    """Render a duration given in seconds as hours with an ``h`` suffix.

    Ten hours or more are shown as whole hours (``"13h"``); anything shorter
    keeps one decimal place (``"2.5h"``).
    """
    h = seconds / 3600
    return f"{round(h):.0f}h" if h >= 10 else f"{h:.1f}h"


def build_weekly_message(s: Snapshot, hours_left: float, threshold_hours: int, threshold_remaining: int) -> str:
    """Build the alert text for weekly Codex quota that is about to go unused.

    Args:
        s: Snapshot the alert is based on.
        hours_left: Hours until the weekly (secondary) window resets.
        threshold_hours: Hours-left bound of the weekly threshold that fired
            (not included in the text).
        threshold_remaining: Free-percent bound of that threshold (not included
            in the text).

    Returns:
        A one-line Spanish message. It gives the hours left, the free weekly and
        5h percentages, and the number of 5h-window resets still ahead before
        the weekly reset. It ends with a status: ``actionable_now`` when at least
        ACTIONABLE_PRIMARY_FREE_PERCENT of the 5h window is free, otherwise
        ``blocked_by_5h``.
    """
    weekly_free = max(0.0, 100.0 - s.secondary_used)
    primary_free = max(0.0, 100.0 - s.primary_used)
    now = datetime.now(timezone.utc).timestamp()
    seconds_left = max(0.0, s.secondary_resets_at - now)
    primary_window_sec = max(1, s.primary_window_min * 60)
    if s.primary_resets_at > now:
        # The next 5h reset time is known. Count it plus each later window
        # boundary (primary_resets_at + k * window, assuming windows run
        # back-to-back) that falls strictly before the weekly reset. Counting
        # whole windows from "now" instead misses it: 4h left with the 5h
        # window resetting in 1h is one reset, not zero.
        gap = s.secondary_resets_at - s.primary_resets_at
        primary_resets_left = int(-(-gap // primary_window_sec)) if gap > 0 else 0
    else:
        # Next 5h reset unknown (field missing or already past): approximate
        # with the number of whole windows that fit before the weekly reset.
        primary_resets_left = max(0, math.floor(seconds_left / primary_window_sec))
    status = "actionable_now" if primary_free >= ACTIONABLE_PRIMARY_FREE_PERCENT else "blocked_by_5h"
    return (
        f"Codex semanal desaprovechado. "
        f"Faltan {fmt_hours(hours_left * 3600)}. "
        f"Libre semanal {weekly_free:.0f}%. "
        f"Libre 5h {primary_free:.0f}%. "
        f"Resets 5h antes del cierre: {primary_resets_left}. "
        f"Estado: {status}."
    )


def build_primary_message(s: Snapshot, level: int) -> str:
    """Build the alert text for heavy use of the 5h (primary) window.

    The text states the used and free percentages and how long until the 5h
    window resets (never negative). *level* is the warning level that fired;
    it is accepted but not shown in the text.
    """
    now_ts = int(datetime.now(timezone.utc).timestamp())
    seconds_to_reset = max(0, s.primary_resets_at - now_ts)
    free_pct = max(0.0, 100.0 - s.primary_used)
    usage_part = f"Codex 5h alto uso: {s.primary_used:.0f}% usado, {free_pct:.0f}% libre. "
    reset_part = f"Reset en {fmt_hours(seconds_to_reset)}."
    return usage_part + reset_part


def send_telegram(message: str, timeout: float | None = 120) -> None:
    """Deliver *message* by running the telegram-notify helper at TELEGRAM_NOTIFY.

    The message is passed as a single argv element (no shell involved).

    Args:
        message: Text to send.
        timeout: Seconds to wait for the helper before killing it. ``None`` waits
            forever, which was the previous behavior and could stall the job on a
            hung network call.

    Raises:
        subprocess.CalledProcessError: the helper exited with a non-zero status.
        subprocess.TimeoutExpired: the helper did not finish within *timeout*.
        OSError: the helper could not be executed.
    """
    subprocess.run([str(TELEGRAM_NOTIFY), message], check=True, timeout=timeout)


def _reset_alert_state(state: dict[str, Any], section: str, reset_key: str) -> dict[str, Any]:
    """Return ``state[section]``, starting it afresh when *reset_key* names a new window."""
    section_state = state.setdefault(section, {})
    if section_state.get("reset") != reset_key:
        # A different resets_at means a new rate-limit window: forget which
        # thresholds alerted during the previous one.
        section_state.clear()
        section_state["reset"] = reset_key
        section_state["sent"] = []
    return section_state


def maybe_emit_alerts(s: Snapshot, state: dict[str, Any]) -> list[str]:
    """Decide which alerts snapshot *s* warrants and record them in *state*.

    *state* is mutated in place (the caller persists it):
      - ``stale_snapshot`` is set, and no alert is produced, when the snapshot's
        timestamp is older than MAX_SNAPSHOT_AGE_HOURS; otherwise it is removed.
      - ``weekly`` / ``primary`` record which thresholds already alerted. Each is
        keyed by its window's ``resets_at``, so a new window re-arms them.
      - ``last_snapshot`` mirrors the evaluated snapshot.

    At most one weekly and one primary message are produced per call. When
    several thresholds are crossed at once, the most urgent one is reported and
    all crossed ones are marked as sent. The less urgent ones therefore do not
    trickle out as redundant alerts on later runs.

    Returns:
        Message texts to deliver, weekly first.
    """
    now = datetime.now(timezone.utc).timestamp()
    try:
        snapshot_ts = datetime.fromisoformat(s.ts.replace("Z", "+00:00")).timestamp()
    except (TypeError, ValueError, OverflowError, OSError):
        # Missing or unparseable timestamp: treat the snapshot as current.
        snapshot_ts = now
    age_sec = now - snapshot_ts
    if age_sec > MAX_SNAPSHOT_AGE_HOURS * 3600:
        # Too old to act on: record why nothing was sent and stop.
        state["stale_snapshot"] = {"ts": s.ts, "age_hours": round(age_sec / 3600, 1)}
        return []
    state.pop("stale_snapshot", None)

    hours_left = max(0.0, s.secondary_resets_at - now) / 3600
    weekly_free = max(0.0, 100.0 - s.secondary_used)
    primary_free = max(0.0, 100.0 - s.primary_used)
    sent: list[str] = []

    # Weekly window: thresholds crossed now and not yet alerted, tightest deadline first.
    weekly_state = _reset_alert_state(state, "weekly", str(s.secondary_resets_at))
    weekly_sent = set(weekly_state.get("sent") or [])
    due_weekly = [
        (hours, remaining)
        for hours, remaining in sorted(WEEKLY_THRESHOLDS, key=lambda x: x[0])
        if f"{hours}h_{remaining}pct" not in weekly_sent
        and hours_left <= hours
        and weekly_free >= remaining
    ]
    if due_weekly:
        hours, remaining = due_weekly[0]
        sent.append(build_weekly_message(s, hours_left, hours, remaining))
        weekly_sent.update(f"{h}h_{r}pct" for h, r in due_weekly)
    weekly_state["sent"] = sorted(weekly_sent)
    weekly_state["last_weekly_free"] = weekly_free
    weekly_state["last_primary_free"] = primary_free

    # 5h window: levels reached and not yet alerted, highest first.
    primary_state = _reset_alert_state(state, "primary", str(s.primary_resets_at))
    primary_sent = set(primary_state.get("sent") or [])
    due_primary = [
        level
        for level in sorted(PRIMARY_WARN_LEVELS, reverse=True)
        if str(level) not in primary_sent and s.primary_used >= level
    ]
    if due_primary:
        sent.append(build_primary_message(s, due_primary[0]))
        primary_sent.update(str(level) for level in due_primary)
    primary_state["sent"] = sorted(primary_sent, key=int)
    primary_state["last_primary_used"] = s.primary_used

    state["last_snapshot"] = {
        "ts": s.ts,
        "primary_used": s.primary_used,
        "primary_free": primary_free,
        "primary_resets_at": s.primary_resets_at,
        "secondary_used": s.secondary_used,
        "secondary_free": weekly_free,
        "secondary_resets_at": s.secondary_resets_at,
        "plan_type": s.plan_type,
    }
    return sent


def print_status(s: Snapshot | None, state: dict[str, Any]) -> None:
    """Print the latest snapshot and the stored alert state as indented JSON.

    The printed object has two keys: ``snapshot`` (the snapshot's fields, or
    null when none was found) and ``state`` (*state* unchanged).
    """
    snapshot_fields = (
        "ts",
        "primary_used",
        "primary_window_min",
        "primary_resets_at",
        "secondary_used",
        "secondary_window_min",
        "secondary_resets_at",
        "plan_type",
    )
    snapshot_view = (
        {field: getattr(s, field) for field in snapshot_fields} if s is not None else None
    )
    report = {"snapshot": snapshot_view, "state": state}
    print(json.dumps(report, indent=2))


def main() -> int:
    """Command-line entry point; returns the process exit status.

    Flags (read from ``sys.argv``):
      --status   print the latest snapshot and stored state as JSON, exit 0.
      --dry-run  print the alerts that would be sent, without sending them and
                 without saving state, so a later real run still delivers them.

    Without flags, each alert is sent via Telegram and echoed to stdout, then the
    updated state is saved. Returns 1 when no snapshot is found, or when an alert
    cannot be sent. In the latter case state is not saved, so the alerts are
    retried on the next run.
    """
    args = sys.argv[1:]
    dry_run = "--dry-run" in args
    status_only = "--status" in args
    s = load_latest_snapshot()
    state = load_state()
    if status_only:
        print_status(s, state)
        return 0
    if s is None:
        print("no_codex_rate_limits_found", file=sys.stderr)
        return 1
    messages = maybe_emit_alerts(s, state)
    if dry_run:
        # Preview only: saving here would mark these alerts as already sent and
        # silently suppress them on the next real run.
        for m in messages:
            print(m)
        return 0
    for m in messages:
        try:
            send_telegram(m)
        except (OSError, subprocess.SubprocessError) as exc:
            print(f"telegram_send_failed: {exc}", file=sys.stderr)
            return 1
        print(m)
    save_state(state)
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    sys.exit(main())
