#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import os
import shlex
import subprocess
import sys
import time
from pathlib import Path
from typing import Any

# Directory two levels above this file (presumably the repo root, with this script in <BASE_DIR>/scripts/).
BASE_DIR = Path(__file__).resolve().parent.parent
# Root searched recursively for *.jsonl session transcripts.
SESSIONS_DIR = Path(os.environ.get("PI_SESSION_DIR", BASE_DIR / "sessions"))
# Where job files (*.json) and analyzer logs (*.log) are written and scanned.
JOB_DIR = Path(os.environ.get("PI_SESSION_MEMORY_JOB_DIR", SESSIONS_DIR / "session-memory-jobs"))
# JSON file holding per-session "mtime_ns:size" markers and the last run timestamp.
STATE_FILE = Path(os.environ.get("PI_SESSION_MEMORY_SWEEP_STATE", BASE_DIR / "runtime" / "session-memory-sweeper-state.json"))
# Analyzer script, executed with the current Python interpreter for each job.
ANALYZER = Path(os.environ.get("PI_SESSION_MEMORY_ANALYZER", BASE_DIR / "scripts" / "session-memory-analyzer.py"))
# Session files whose mtime is older than this many days are ignored.
LOOKBACK_DAYS = int(os.environ.get("PI_SESSION_MEMORY_LOOKBACK_DAYS", "21"))
# Cap on result entries per run; skipped and already-queued sessions count toward it.
MAX_SESSIONS_PER_RUN = int(os.environ.get("PI_SESSION_MEMORY_MAX_SESSIONS_PER_RUN", "12"))
# A session is "too small" only when it has fewer messages than MIN_TOTAL_MESSAGES
# AND its last assistant text is shorter than MIN_ASSISTANT_TEXT characters.
MIN_TOTAL_MESSAGES = int(os.environ.get("PI_SESSION_MEMORY_MIN_TOTAL_MESSAGES", "6"))
MIN_ASSISTANT_TEXT = int(os.environ.get("PI_SESSION_MEMORY_MIN_ASSISTANT_TEXT", "120"))
# Notification helper; notifications are skipped when this path does not exist.
TELEGRAM_NOTIFY_BIN = Path(os.environ.get("TELEGRAM_NOTIFY_BIN", str(Path.home() / ".agents" / "skills" / "telegram-notify" / "telegram-notify")))
# NOTE(review): the two defaults below are machine-specific absolute paths; override via env elsewhere.
MAINTENANCE_SCRIPT = Path(os.environ.get("AGENTS_DB_MAINTENANCE_SCRIPT", "/home/sebas/agents-database/scripts/run-maintenance-daemon.sh"))
# Exported to the maintenance script's environment as DB_PATH.
DB_PATH = os.environ.get("DB_PATH", "/home/sebas/agents-database/data/shared-agent-memory.sqlite3")

# Lower-cased prefixes of a session's first user message that mark automated
# loop sessions; such sessions are skipped rather than analyzed.
SKIP_PREFIXES = (
    "you are iteration ",
    "continue the bounded ralph-loop",
)


def load_state() -> dict[str, Any]:
    if not STATE_FILE.exists():
        return {"processed": {}, "last_run_at": None}
    try:
        return json.loads(STATE_FILE.read_text(encoding="utf-8"))
    except Exception:
        return {"processed": {}, "last_run_at": None}


def save_state(state: dict[str, Any]) -> None:
    STATE_FILE.parent.mkdir(parents=True, exist_ok=True)
    STATE_FILE.write_text(json.dumps(state, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")


def iter_session_files() -> list[Path]:
    cutoff = time.time() - LOOKBACK_DAYS * 86400
    paths = [p for p in SESSIONS_DIR.rglob("*.jsonl") if p.is_file() and p.stat().st_mtime >= cutoff]
    return sorted(paths, key=lambda p: p.stat().st_mtime)


def _message_text(message: dict[str, Any]) -> str:
    """Return a message's stripped text: string content, or its text parts joined by newlines."""
    content = message.get("content")
    if isinstance(content, str):
        # Accept plain-string content as well as a list of typed parts;
        # iterating a string part-by-part would crash on ``part.get``.
        return content.strip()
    if not isinstance(content, list):
        return ""
    return "\n".join(
        part["text"]
        for part in content
        if isinstance(part, dict) and part.get("type") == "text" and isinstance(part.get("text"), str)
    ).strip()


def parse_session(path: Path) -> dict[str, Any]:
    """Summarise a JSONL session transcript into an analysis-job payload.

    The file is read line by line. Unparseable lines and non-object entries
    are ignored. A ``"session"`` entry supplies ``id``/``cwd``, and
    ``"message"`` entries with role user/assistant are collected into
    ``branch`` and counted.

    Args:
        path: Session ``.jsonl`` file.

    Returns:
        Payload dict with session id, file, cwd (defaulting to the home
        directory), inferred scope, fixed reason/trigger, the collected
        message entries, per-role counts, and truncated
        first-user / last-user / last-assistant texts.

    Raises:
        OSError: if the file cannot be opened.
    """
    session_id = None
    cwd = None
    branch: list[dict[str, Any]] = []
    user_messages = 0
    assistant_messages = 0
    first_user_text = ""
    last_user_text = ""
    last_assistant_text = ""

    with path.open("r", encoding="utf-8", errors="ignore") as handle:
        for line in handle:
            try:
                entry = json.loads(line)
            except (ValueError, RecursionError):
                continue  # blank, partial, or corrupt line
            if not isinstance(entry, dict):
                continue
            entry_type = entry.get("type")
            if entry_type == "session":
                session_id = entry.get("id") or session_id
                cwd = entry.get("cwd") or cwd
                continue
            if entry_type != "message":
                continue
            message = entry.get("message")
            if not isinstance(message, dict):
                continue
            role = message.get("role")
            if role not in ("user", "assistant"):
                continue
            text = _message_text(message)
            branch.append(entry)
            if role == "user":
                user_messages += 1
                if text:
                    first_user_text = first_user_text or text
                    last_user_text = text
            else:
                assistant_messages += 1
                if text:
                    last_assistant_text = text

    effective_cwd = cwd or str(Path.home())
    return {
        "session_id": session_id,
        "session_file": str(path),
        "cwd": effective_cwd,
        "scope": infer_scope(effective_cwd),
        "reason": "periodic_sweep",
        "trigger": "periodic_sweeper",
        "branch": branch,
        "user_messages": user_messages,
        "assistant_messages": assistant_messages,
        "last_user_text": truncate(last_user_text, 2000),
        "last_assistant_text": truncate(last_assistant_text, 2000),
        "first_user_text": truncate(first_user_text, 400),
    }


def infer_scope(cwd: str) -> str:
    """Classify a working directory as ``"global"`` or ``"project"`` scope.

    The home directory itself, ``~/.pi`` and anything beneath ``~/.pi`` are
    global. Everything else is a project. Both paths are resolved before
    comparing, so symlinked homes compare correctly.
    """
    try:
        home = Path.home().resolve()
        target = Path(cwd).resolve()
    except (OSError, RuntimeError):
        # Fall back to unresolved paths for both so the comparison stays consistent.
        home = Path.home()
        target = Path(cwd)
    pi_dir = home / ".pi"
    # Compare whole path components; a plain string-prefix test would also
    # match siblings such as ``~/.pilot``.
    if target == home or target == pi_dir or pi_dir in target.parents:
        return "global"
    return "project"


def truncate(text: str, max_len: int) -> str:
    """Collapse whitespace in *text* and cap it at *max_len* characters.

    Runs of whitespace become single spaces, and leading/trailing whitespace
    is dropped. Text longer than *max_len* is cut and ends in ``"…"``, so the
    result is exactly *max_len* characters long. A non-positive *max_len*
    yields ``""`` whenever truncation is needed.
    """
    collapsed = " ".join(text.split())
    if len(collapsed) <= max_len:
        return collapsed
    if max_len <= 0:
        # No room even for the ellipsis; slicing with max_len - 1 would go negative.
        return ""
    return collapsed[: max_len - 1] + "…"


def should_skip(meta: dict[str, Any]) -> tuple[bool, str | None]:
    """Decide whether a parsed session is not worth analysing.

    Returns:
        ``(True, reason)`` where reason is ``"no_dialogue"`` (a role has no
        messages), ``"loop_session"`` (first user text starts with one of
        ``SKIP_PREFIXES``) or ``"too_small"`` (few messages and a short final
        assistant reply); otherwise ``(False, None)``.
    """
    user_count = meta["user_messages"]
    if user_count <= 0:
        return True, "no_dialogue"
    assistant_count = meta["assistant_messages"]
    if assistant_count <= 0:
        return True, "no_dialogue"

    opening_text = (meta.get("first_user_text") or "").strip().lower()
    if opening_text.startswith(SKIP_PREFIXES):
        return True, "loop_session"

    # "Too small" requires BOTH a low message count and a short final reply.
    if user_count + assistant_count >= MIN_TOTAL_MESSAGES:
        return False, None
    reply_length = len(meta.get("last_assistant_text") or "")
    if reply_length >= MIN_ASSISTANT_TEXT:
        return False, None
    return True, "too_small"


def existing_job_for_session(session_file: str) -> bool:
    for path in JOB_DIR.glob("*.json"):
        try:
            payload = json.loads(path.read_text(encoding="utf-8"))
        except Exception:
            continue
        if payload.get("payload", {}).get("session_file") == session_file:
            return True
    return False


def write_job(payload: dict[str, Any]) -> Path:
    JOB_DIR.mkdir(parents=True, exist_ok=True)
    job_id = time.strftime("%Y%m%dT%H%M%S", time.gmtime()) + f"{int(time.time_ns()%1_000_000):06d}Z"
    job_file = JOB_DIR / f"{job_id}.json"
    job = {
        "job_id": job_id,
        "queued_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "status": "queued",
        "job_type": "session_memory_analysis",
        "payload": payload,
        "analysis": {
            "status": "queued",
            "save": None,
            "memory_id": None,
        },
    }
    job_file.write_text(json.dumps(job, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")
    return job_file


def run_analyzer(job_file: Path) -> int:
    """Run the analyzer script on *job_file* and return its exit code.

    The analyzer runs with the current Python interpreter. Its stdout followed
    by its stderr is written next to the job as ``<job>.log``. Output is
    decoded as UTF-8 with replacement characters, so unexpected bytes from the
    analyzer cannot abort the sweep with ``UnicodeDecodeError``.
    """
    proc = subprocess.run(
        [sys.executable, str(ANALYZER), "--job-file", str(job_file)],
        capture_output=True,
        encoding="utf-8",
        errors="replace",
    )
    log_file = job_file.with_suffix(".log")
    log_file.write_text((proc.stdout or "") + (proc.stderr or ""), encoding="utf-8")
    return proc.returncode


def classify_agents_candidates(analysis: dict[str, Any]) -> list[str]:
    """Return titles of analysed memories that look like AGENTS-file rules.

    A memory qualifies when it is a dict with ``scope == "global"``, type
    profile/decision/procedural, confidence >= 0.85, importance >= 0.8, and
    its title or content contains a rule-like keyword. Non-numeric scores
    count as 0. Untitled matches are reported as ``"untitled"``.

    Args:
        analysis: The job's ``analysis`` section; memories are read from
            ``analysis["result"]["memories"]``.
    """
    def as_float(value: Any) -> float:
        try:
            return float(value or 0)
        except (TypeError, ValueError):
            return 0.0

    result = analysis.get("result")
    memories = result.get("memories") if isinstance(result, dict) else None
    if not isinstance(memories, list):
        return []
    keywords = ("always", "default", "prefer", "should", "avoid", "never", "every conversation", "working rule")
    items = []
    for memory in memories:
        if not isinstance(memory, dict):
            continue
        if memory.get("scope") != "global":
            continue
        if memory.get("type") not in {"profile", "decision", "procedural"}:
            continue
        # `not (x >= t)` also rejects NaN, which `x < t` would let through.
        if not (as_float(memory.get("confidence")) >= 0.85 and as_float(memory.get("importance")) >= 0.8):
            continue
        text = f"{memory.get('title') or ''} {memory.get('content') or ''}".lower()
        if any(keyword in text for keyword in keywords):
            items.append(memory.get("title") or "untitled")
    return items


def maybe_send_telegram(message: str) -> None:
    if not message.strip() or not TELEGRAM_NOTIFY_BIN.exists():
        return
    subprocess.run([str(TELEGRAM_NOTIFY_BIN), message], check=False, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)


def run_maintenance_once() -> dict[str, Any] | None:
    if not MAINTENANCE_SCRIPT.exists():
        return None
    env = dict(os.environ)
    env["DB_PATH"] = DB_PATH
    proc = subprocess.run([str(MAINTENANCE_SCRIPT), "--once"], capture_output=True, text=True, env=env)
    raw = (proc.stdout or "").strip().splitlines()
    if proc.returncode != 0 or not raw:
        return None
    try:
        payload = json.loads(raw[-1])
    except Exception:
        return None
    promoted = 0
    archived = 0
    linked = 0
    conflicts = 0
    for item in payload:
        stats = item.get("stats") or {}
        promoted += int(stats.get("promoted") or 0)
        archived += int(stats.get("archived") or 0)
        linked += int(stats.get("linked") or 0)
        conflicts += int(stats.get("conflicts") or 0)
    return {"promoted": promoted, "archived": archived, "linked": linked, "conflicts": conflicts}


def _read_job_analysis(job_file: Path) -> dict[str, Any]:
    """Return the ``analysis`` section of *job_file*, or ``{}`` if it is unreadable or malformed."""
    try:
        job = json.loads(job_file.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        return {}
    analysis = job.get("analysis") if isinstance(job, dict) else None
    return analysis if isinstance(analysis, dict) else {}


def _sweep_session(session_file: Path, processed: dict[str, Any]) -> tuple[dict[str, Any] | None, list[str]]:
    """Handle one session file and return ``(result record or None, notifications)``.

    ``None`` means nothing was done: the file is unchanged since its recorded
    marker, or it vanished. Once the file has been handled, its
    ``mtime_ns:size`` marker is stored in *processed*.
    """
    key = str(session_file)
    try:
        info = session_file.stat()
    except OSError:
        # The file vanished (or became unreadable) after it was listed.
        return None, []
    marker = f"{info.st_mtime_ns}:{info.st_size}"
    if processed.get(key) == marker:
        return None, []
    try:
        meta = parse_session(session_file)
    except OSError:
        return None, []
    skip, reason = should_skip(meta)
    if existing_job_for_session(key):
        processed[key] = marker
        return {"session_file": key, "status": "already_queued"}, []
    if skip:
        processed[key] = marker
        return {"session_file": key, "status": "skipped", "reason": reason}, []

    job_file = write_job(meta)
    exit_code = run_analyzer(job_file)
    processed[key] = marker
    result: dict[str, Any] = {
        "session_file": key,
        "status": "analyzed" if exit_code == 0 else "failed",
        "job_file": str(job_file),
    }
    if exit_code != 0:
        return result, []

    # The analyzer is expected to have written its results back into the job file.
    analysis = _read_job_analysis(job_file)
    saved_count = analysis.get("saved_count")
    result["saved_count"] = saved_count
    notes: list[str] = []
    if saved_count:
        notes.append(f"Memory sweep saved {saved_count} from {session_file.name}")
    candidates = classify_agents_candidates(analysis)
    if candidates:
        notes.append("Possible AGENTS promotion:\n- " + "\n- ".join(candidates[:4]))
    return result, notes


def main() -> int:
    """Run one sweep: analyze new or changed sessions, run DB maintenance, notify, save state.

    Prints a one-line JSON summary to stdout and returns 0.
    """
    parser = argparse.ArgumentParser()
    # Kept for CLI compatibility; every invocation performs exactly one sweep.
    parser.add_argument("--once", action="store_true")
    parser.parse_args()

    state = load_state()
    if not isinstance(state, dict):
        state = {"processed": {}, "last_run_at": None}
    processed = state.setdefault("processed", {})
    if not isinstance(processed, dict):
        processed = state["processed"] = {}
    results: list[dict[str, Any]] = []
    notifications: list[str] = []
    maintenance = None

    try:
        for session_file in iter_session_files():
            # The cap counts every recorded result, including skipped and already-queued ones.
            if len(results) >= MAX_SESSIONS_PER_RUN:
                break
            result, notes = _sweep_session(session_file, processed)
            if result is not None:
                results.append(result)
                notifications.extend(notes)

        maintenance = run_maintenance_once()
        if maintenance and any(maintenance.values()):
            notifications.append(
                "Memory maintenance: "
                f"promoted {maintenance['promoted']}, "
                f"archived {maintenance['archived']}, "
                f"linked {maintenance['linked']}, "
                f"conflicts {maintenance['conflicts']}"
            )

        if notifications:
            maybe_send_telegram("\n\n".join(notifications[:6]))

        state["last_run_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    finally:
        # Save even if a step above raised, so sessions already handled keep
        # their markers and are not reprocessed on the next run.
        save_state(state)

    print(json.dumps({"processed_now": len(results), "results": results, "maintenance": maintenance, "notifications": notifications}, ensure_ascii=False))
    return 0


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
