package rpc

import (
	"bufio"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"sync"
	"sync/atomic"
)

// dispatchJob is one queued incoming notification, handed from processLine to
// the notification dispatcher goroutine.
type dispatchJob struct {
	method string          // notification method name
	params json.RawMessage // raw "params" member; nil when the message had none
	wg     *sync.WaitGroup // WaitGroup the dispatcher marks Done for this job; may be nil
}

// pendingEntry is the single outcome delivered to a goroutine waiting on an
// outgoing request: either the raw result or an error, never both.
type pendingEntry struct {
	result json.RawMessage // raw "result" member of a successful response
	err    error           // set when the request failed, was cancelled, or was rejected
}

const (
	// defaultNotifCapacity is the buffer size of the queue feeding the single
	// notification dispatcher goroutine. When the queue is full, processLine
	// drops the incoming notification (logging a warning) rather than blocking
	// the receive loop.
	defaultNotifCapacity = 1024

	// methodCancelRequest is the reserved notification method used to ask the
	// other side to cancel a request.
	methodCancelRequest = "$/cancel_request"
)

// Connection is a bidirectional JSON-RPC 2.0 connection over newline-delimited JSON.
// It supports sending requests (with response tracking), sending notifications,
// and receiving incoming requests and notifications via a Handler.
type Connection struct {
	scanner *bufio.Scanner
	writer  io.Writer
	handler Handler

	// Logger receives diagnostics. When nil (the default) nothing is logged;
	// see SetLogger.
	Logger *slog.Logger

	// mu guards pending, which maps each outgoing request ID to the channel
	// that receives its response. nextID generates those IDs.
	mu      sync.Mutex
	pending map[int64]chan pendingEntry
	nextID  atomic.Int64

	// writeMu serializes writes so concurrent messages are not interleaved.
	writeMu sync.Mutex

	// notifCh feeds incoming notifications to notificationDispatcher, which
	// handles them one at a time. wg tracks the dispatcher and in-flight
	// incoming-request goroutines so Close can wait for them.
	notifCh chan dispatchJob
	wg      sync.WaitGroup

	// pendingNotifWg tracks in-flight notification dispatches. It is used by
	// callers that want to drain all notifications triggered by a request
	// before the request returns (e.g. session/prompt + streaming chunks).
	// pendingNotifWgRef is guarded by pendingNotifMu.
	pendingNotifWg    sync.WaitGroup
	pendingNotifMu    sync.Mutex
	pendingNotifWgRef *sync.WaitGroup // wg captured at request start, drained at response

	// ctx is cancelled by Close (via cancel) or when the parent context ends.
	ctx    context.Context
	cancel context.CancelFunc
}

// SetLogger installs logger as the destination for this connection's
// diagnostic output. Until it is called (or after it is called with nil) the
// connection logs nothing.
func (c *Connection) SetLogger(logger *slog.Logger) {
	c.Logger = logger
}

// log yields the logger configured with SetLogger, falling back to a logger
// that discards everything so call sites never need a nil check.
func (c *Connection) log() *slog.Logger {
	if configured := c.Logger; configured != nil {
		return configured
	}
	return slog.New(slog.DiscardHandler)
}

// NewConnection builds a JSON-RPC connection that reads messages from reader
// and writes them to writer, routing incoming requests and notifications to
// handler. The connection's context is derived from ctx. Logging is off until
// SetLogger is called.
//
// A notification dispatcher goroutine is started immediately; call Close to
// stop it. Incoming messages are only read once ReceiveLoop is running.
func NewConnection(ctx context.Context, handler Handler, reader io.Reader, writer io.Writer) *Connection {
	connCtx, cancel := context.WithCancel(ctx)
	c := &Connection{
		scanner: bufio.NewScanner(reader),
		writer:  writer,
		handler: handler,
		pending: make(map[int64]chan pendingEntry),
		notifCh: make(chan dispatchJob, defaultNotifCapacity),
		ctx:     connCtx,
		cancel:  cancel,
	}

	// Register the dispatcher with c.wg before it starts so Close always waits for it.
	c.wg.Add(1)
	go c.notificationDispatcher()
	return c
}

// SendRequest sends a JSON-RPC request and waits for its response. It is a
// thin wrapper over SendRequestRaw: on success the returned value holds the
// raw result as a json.RawMessage; on failure it is nil and err is set.
func (c *Connection) SendRequest(ctx context.Context, method string, params interface{}) (interface{}, error) {
	result, err := c.SendRequestRaw(ctx, method, params)
	return result, err
}

// SendRequest is the typed counterpart of Connection.SendRequest. It sends
// method with params (any JSON-serializable value) over c, waits for the
// response, and decodes the result into R with DecodeResponse. If the request
// fails, the zero R is returned together with the error.
func SendRequest[R any](ctx context.Context, c *Connection, method string, params any) (R, error) {
	raw, err := c.SendRequestRaw(ctx, method, params)
	if err != nil {
		var zero R
		return zero, err
	}
	return DecodeResponse[R](raw)
}

// SendRequestRaw sends a JSON-RPC request for method. It blocks until one of:
//   - the matching response arrives;
//   - ctx is done;
//   - another part of the connection fails the pending request, either via
//     rejectPending (e.g. from Close or when ReceiveLoop ends) or via a
//     $/cancel_request handled by handleCancelRequest.
//
// On success it returns the response's raw "result" member (a
// json.RawMessage stored in the interface{}). On failure the result is nil
// and the error is one of:
//   - a "marshal params" error;
//   - a "write request" error;
//   - the *RPCError from an error response, or an error decoding it;
//   - the rejection or cancellation error delivered by the connection;
//   - a "request cancelled" error wrapping ctx.Err().
//
// When a response (success or error) arrives, it waits on the shared
// pendingNotifWg before returning. This ensures notifications enqueued while
// the request was in flight are drained by the notification dispatcher first.
//
// NOTE(review): every request Adds to the same c.pendingNotifWg, so the wait
// above also covers other requests that are still in flight. A request can
// therefore stay blocked until unrelated concurrent requests complete.
// Concurrent Add/Wait on one reused WaitGroup may also violate
// sync.WaitGroup's reuse rules ("WaitGroup is reused before previous Wait has
// returned"). Confirm whether concurrent requests are expected before relying
// on this.
func (c *Connection) SendRequestRaw(ctx context.Context, method string, params interface{}) (interface{}, error) {
	id := c.nextID.Add(1)

	paramsJSON, err := json.Marshal(params)
	if err != nil {
		return nil, fmt.Errorf("marshal params: %w", err)
	}

	req := map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      id,
		"method":  method,
	}
	// json.Marshal never returns nil without an error (a nil params encodes as
	// "null"), so "params" is always included here.
	if paramsJSON != nil {
		req["params"] = json.RawMessage(paramsJSON)
	}

	// Buffered (capacity 1) so whoever delivers the single outcome never
	// blocks, even if this call has already returned on ctx.Done.
	ch := make(chan pendingEntry, 1)

	// Capture pending notification wg so all notifications dispatched from
	// this point until the response arrives are drained before we return.
	c.pendingNotifMu.Lock()
	c.pendingNotifWg.Add(1)
	c.pendingNotifWgRef = &c.pendingNotifWg
	c.pendingNotifMu.Unlock()

	// Register before writing so a fast response always finds its entry.
	c.mu.Lock()
	c.pending[id] = ch
	c.mu.Unlock()

	if err := c.writeMessage(req); err != nil {
		c.removePending(id)
		c.pendingNotifMu.Lock()
		c.pendingNotifWg.Done()
		c.pendingNotifMu.Unlock()
		return nil, fmt.Errorf("write request: %w", err)
	}

	select {
	case entry := <-ch:
		if entry.err != nil {
			// Drain pending notifications before returning error.
			c.pendingNotifMu.Lock()
			wg := c.pendingNotifWgRef
			c.pendingNotifMu.Unlock()
			wg.Done()
			wg.Wait()
			return nil, entry.err
		}
		// Drain pending notifications (e.g. session updates) before returning.
		c.pendingNotifMu.Lock()
		wg := c.pendingNotifWgRef
		c.pendingNotifMu.Unlock()
		wg.Done()
		wg.Wait()
		return entry.result, nil
	case <-ctx.Done():
		// ctx ended first. Dropping the pending entry makes handleResponse
		// ignore a late response. The peer is not told to stop processing the
		// request, and this call does not wait for notifications to drain.
		c.removePending(id)
		c.pendingNotifMu.Lock()
		c.pendingNotifWg.Done()
		c.pendingNotifMu.Unlock()
		return nil, fmt.Errorf("request cancelled: %w", ctx.Err())
	}
}

// SendNotification writes a JSON-RPC notification: a message with a method
// and params but no id, so the peer sends no response. Because json.Marshal
// encodes a nil params as null, a nil params is sent as "params": null. It
// returns a "marshal params" error if params cannot be encoded, otherwise
// whatever writeMessage returns.
func (c *Connection) SendNotification(method string, params interface{}) error {
	notif := map[string]interface{}{
		"jsonrpc": "2.0",
		"method":  method,
	}

	encoded, err := json.Marshal(params)
	if err != nil {
		return fmt.Errorf("marshal params: %w", err)
	}
	if encoded != nil {
		notif["params"] = json.RawMessage(encoded)
	}

	return c.writeMessage(notif)
}

// CancelRequest asks the peer to cancel the request identified by requestID.
// It sends a $/cancel_request notification with params {"requestId": requestID}
// and returns any error from SendNotification.
func (c *Connection) CancelRequest(requestID int64) error {
	payload := map[string]interface{}{"requestId": requestID}
	return c.SendNotification(methodCancelRequest, payload)
}

// ReceiveLoop reads newline-delimited JSON-RPC messages from the reader and
// processes each non-empty line with processLine. It blocks the calling
// goroutine until the reader is exhausted or fails.
//
// When the loop ends, every still-pending outgoing request is rejected with:
//   - "connection closed" on EOF;
//   - "message exceeds 10 MB scan limit" if a single line is too long;
//   - "receive error: ..." (wrapping the cause) for any other read error.
//
// The exception is a read error after the connection's context has already
// been cancelled. That case is treated as a normal shutdown, and pending
// requests are left for Close to reject.
func (c *Connection) ReceiveLoop() {
	const (
		initialBufSize = 1024 * 1024      // 1 MB
		maxMessageSize = 10 * 1024 * 1024 // 10 MB: longest accepted line
	)
	c.scanner.Buffer(make([]byte, initialBufSize), maxMessageSize)

	for c.scanner.Scan() {
		line := c.scanner.Bytes()
		if len(line) == 0 {
			continue
		}
		c.processLine(line)
	}

	err := c.scanner.Err()
	switch {
	case err == nil:
		// Scanner reached EOF.
		c.rejectPending(errors.New("connection closed"))
	case c.ctx.Err() != nil:
		// Context was cancelled; normal shutdown.
		return
	case errors.Is(err, bufio.ErrTooLong):
		c.rejectPending(errors.New("message exceeds 10 MB scan limit"))
	default:
		c.rejectPending(fmt.Errorf("receive error: %w", err))
	}
}

// Close shuts down the connection. In order, it:
//   - cancels the connection context;
//   - closes the notification queue, so the dispatcher exits after handling
//     whatever is already queued;
//   - rejects all pending outgoing requests with "connection closed";
//   - waits for the dispatcher and in-flight incoming-request goroutines
//     (tracked by c.wg) to finish.
//
// It always returns nil. Close does not close the reader or writer, so a
// ReceiveLoop blocked in a read keeps running until that read returns.
//
// NOTE(review): Close is not idempotent; a second call panics on
// close(c.notifCh). Also, if ReceiveLoop decodes a notification after Close
// has closed notifCh, processLine's send panics: a send on a closed channel
// panics even inside a select with a default case. Consider guarding the
// close with a sync.Once and the enqueue with a closed flag checked under a
// shared lock.
func (c *Connection) Close() error {
	c.cancel()
	// Close the notification channel so the dispatcher exits.
	close(c.notifCh)
	c.rejectPending(fmt.Errorf("connection closed"))
	c.wg.Wait()
	return nil
}

// --- internal ---

// processLine decodes a single newline-delimited JSON-RPC message and routes
// it by shape:
//
//   - "method" and "id": an incoming request, handled on its own goroutine
//     (registered with c.wg so Close waits for it);
//   - "method" only: a notification, queued for the sequential notification
//     dispatcher;
//   - "id" only: a response to one of our outgoing requests.
//
// Lines that are not a JSON object, and notifications whose method is not a
// string, are logged and dropped. A notification is also dropped, with a
// warning, when the dispatcher queue is full, so a slow notification handler
// never stalls the receive loop.
func (c *Connection) processLine(line []byte) {
	var msg map[string]json.RawMessage
	if err := json.Unmarshal(line, &msg); err != nil {
		c.log().Warn("dropping malformed JSON-RPC message", "error", err, "bytes", len(line))
		return
	}

	methodRaw, hasMethod := msg["method"]
	_, hasID := msg["id"]

	switch {
	case hasMethod && hasID:
		// Incoming request (has method and ID).
		c.wg.Add(1)
		go c.handleIncomingRequest(msg)

	case hasMethod:
		// Incoming notification (has method, no ID).
		var method string
		if err := json.Unmarshal(methodRaw, &method); err != nil {
			c.log().Warn("dropping notification with non-string method", "error", err)
			return
		}
		params := msg["params"] // nil when the message has no params

		// Capture the current wg pointer and count this notification against
		// it under the lock, so the dispatcher signals the wg that was current
		// at enqueue time even if pendingNotifWgRef changes in the meantime.
		c.pendingNotifMu.Lock()
		wg := c.pendingNotifWgRef
		if wg != nil {
			wg.Add(1)
		}
		c.pendingNotifMu.Unlock()

		select {
		case c.notifCh <- dispatchJob{method: method, params: params, wg: wg}:
			// Queued for the dispatcher.
		default:
			c.log().Warn("notification channel full, dropping notification", "method", method)
			if wg != nil {
				wg.Done() // undo the Add: this notification will never be dispatched
			}
		}

	case hasID:
		// Incoming response (has ID, no method).
		c.handleResponse(msg)
	}
}

// notificationDispatcher is the single goroutine that handles incoming
// notifications, one at a time in arrival order, until notifCh is closed (see
// Close). Notifications still buffered when the channel is closed are
// dispatched before the goroutine exits.
//
// Each job's WaitGroup, captured when the notification was enqueued, is
// marked Done only after dispatchNotification has returned. A request that
// waits on that WaitGroup therefore sees every earlier notification fully
// handled, not merely dequeued. Previously the last notification (e.g. the
// final streaming chunk) could still be in its handler when SendRequest
// returned.
//
// Because dispatch is sequential, a notification handler must not call
// SendRequest on this connection. That request would wait on a WaitGroup
// that still counts the notification currently being handled, and would
// deadlock.
func (c *Connection) notificationDispatcher() {
	defer c.wg.Done()
	for job := range c.notifCh {
		c.dispatchNotification(job) // recovers handler panics, so this always returns
		if job.wg != nil {
			job.wg.Done()
		}
	}
}

// dispatchNotification runs one queued notification. The reserved
// $/cancel_request method is handled internally by handleCancelRequest. Every
// other method goes to the connection's Handler with false as its third
// argument (handleIncomingRequest passes true for requests). The handler's
// return values are ignored. A panic raised while handling is recovered and
// logged, so the dispatcher goroutine survives.
func (c *Connection) dispatchNotification(job dispatchJob) {
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		c.log().Error("notification handler panic", "panic", r)
	}()

	switch job.method {
	case methodCancelRequest:
		c.handleCancelRequest(job.params)
	default:
		c.handler(job.method, job.params, false)
	}
}

// handleCancelRequest processes an incoming $/cancel_request notification.
// params is expected to look like {"requestId": <integer>}; anything that
// does not decode into that shape is ignored. If requestId matches one of this
// connection's pending outgoing requests, that request is failed with
// "request cancelled by peer".
//
// NOTE(review): the id is looked up among requests this side sent
// (c.pending), while CancelRequest appears intended to cancel requests the
// *sender* issued. If the peer follows that convention, its requestId refers
// to a request it sent to us. This lookup would then mix the two id spaces and
// could fail an unrelated local request. Confirm the intended semantics.
func (c *Connection) handleCancelRequest(params json.RawMessage) {
	var payload struct {
		RequestID int64 `json:"requestId"`
	}
	if json.Unmarshal(params, &payload) != nil {
		return
	}

	if ch := c.removePending(payload.RequestID); ch != nil {
		ch <- pendingEntry{err: fmt.Errorf("request cancelled by peer")}
	}
}

// handleResponse delivers an incoming response to the goroutine waiting on the
// outgoing request with the same numeric id.
//
// Responses whose id cannot be decoded as an int64 are logged and dropped.
// Responses for unknown ids are dropped silently; this covers requests that
// were already answered, rejected, or abandoned on ctx.Done.
//
// A non-null "error" member is decoded into an *RPCError and delivered as the
// error. If it cannot be decoded, an "unmarshal error" is delivered instead.
// Otherwise the raw "result" member is delivered. JSON-RPC 2.0 requires the
// error member to be absent on success, but a peer may send "error": null
// next to a result. That is treated as success rather than as a zero-valued
// RPCError.
func (c *Connection) handleResponse(msg map[string]json.RawMessage) {
	var id int64
	if err := json.Unmarshal(msg["id"], &id); err != nil {
		c.log().Warn("dropping response with unparseable id", "error", err)
		return
	}

	ch := c.removePending(id)
	if ch == nil {
		return
	}

	if errRaw, hasErr := msg["error"]; hasErr && string(errRaw) != "null" {
		rpcErr := new(RPCError)
		if err := json.Unmarshal(errRaw, rpcErr); err != nil {
			ch <- pendingEntry{err: fmt.Errorf("unmarshal error: %w", err)}
			return
		}
		ch <- pendingEntry{err: rpcErr}
		return
	}

	// Success response.
	ch <- pendingEntry{result: msg["result"]}
}

// handleIncomingRequest runs the Handler for one incoming request on its own
// goroutine and writes the response. The handler gets true as its third
// argument. A returned error, or a panic inside the handler, is converted with
// CoerceError and sent as an error response carrying the request's raw id.
// The goroutine is released from c.wg when it finishes.
func (c *Connection) handleIncomingRequest(msg map[string]json.RawMessage) {
	defer c.wg.Done()

	requestID := msg["id"]
	defer func() {
		if r := recover(); r != nil {
			c.sendError(requestID, CoerceError(fmt.Errorf("handler panic: %v", r)))
		}
	}()

	// A non-string method decodes to "" and is passed to the handler as-is.
	var method string
	_ = json.Unmarshal(msg["method"], &method)

	result, err := c.handler(method, msg["params"], true)
	if err != nil {
		c.sendError(requestID, CoerceError(err))
		return
	}
	c.sendResult(requestID, result)
}

// sendResult writes a successful JSON-RPC response carrying result for the
// request with the given raw id.
//
// result is encoded up front. If it cannot be marshalled, an error response
// (internal error) is sent instead; otherwise the write would fail inside
// writeMessage and the peer would wait forever for a reply. A failure to
// write the response is logged, since there is no caller to report it to.
func (c *Connection) sendResult(idRaw json.RawMessage, result interface{}) {
	resultJSON, err := json.Marshal(result)
	if err != nil {
		c.sendError(idRaw, fmt.Errorf("marshal result: %w", err))
		return
	}

	resp := map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      idRaw,
		"result":  json.RawMessage(resultJSON),
	}
	if err := c.writeMessage(resp); err != nil {
		c.log().Warn("failed to write result response", "error", err)
	}
}

// sendError writes a JSON-RPC error response for the request with the given
// raw id. If err is (or wraps) an *RPCError, that error object is sent.
// Otherwise the response carries code -32603 (internal error) with
// err.Error() as the message.
//
// If the RPCError's Data cannot be JSON-encoded, the response is sent without
// Data rather than not at all, so the peer is never left waiting. A failure to
// write the response is logged, since there is no caller to report it to.
func (c *Connection) sendError(idRaw json.RawMessage, err error) {
	var rpcErr *RPCError
	if !errors.As(err, &rpcErr) {
		rpcErr = &RPCError{Code: -32603, Message: err.Error()}
	}

	if rpcErr != nil && rpcErr.Data != nil {
		if _, dataErr := json.Marshal(rpcErr.Data); dataErr != nil {
			c.log().Warn("dropping unencodable error data", "error", dataErr)
			rpcErr = &RPCError{Code: rpcErr.Code, Message: rpcErr.Message}
		}
	}

	resp := map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      idRaw,
		"error":   rpcErr,
	}
	if writeErr := c.writeMessage(resp); writeErr != nil {
		c.log().Warn("failed to write error response", "error", writeErr)
	}
}

// writeMessage encodes msg as JSON, appends the '\n' delimiter, and writes the
// line to the connection's writer in a single Write call. Writes are
// serialized with writeMu so concurrent senders never interleave partial
// lines. It returns the marshal error or the writer's error, if any.
func (c *Connection) writeMessage(msg interface{}) error {
	encoded, err := json.Marshal(msg)
	if err != nil {
		return err
	}
	line := append(encoded, '\n')

	c.writeMu.Lock()
	_, writeErr := c.writer.Write(line)
	c.writeMu.Unlock()
	return writeErr
}

// removePending unregisters the outgoing request with the given id and hands
// back its response channel. It returns nil when no such request is
// registered; deleting an absent key is a no-op.
func (c *Connection) removePending(id int64) chan pendingEntry {
	c.mu.Lock()
	ch := c.pending[id]
	delete(c.pending, id)
	c.mu.Unlock()
	return ch
}

// rejectPending fails every registered outgoing request with err and empties
// the pending table, all while holding mu. Each response channel has capacity
// 1 and receives at most one value, so the sends do not block.
func (c *Connection) rejectPending(err error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	for _, ch := range c.pending {
		ch <- pendingEntry{err: err}
	}
	clear(c.pending)
}

// --- Error types ---

// RPCError is the error object of a JSON-RPC error response. It serializes to
// {"code": ..., "message": ...}, plus "data" when Data is non-nil.
type RPCError struct {
	Code    int         `json:"code"`
	Message string      `json:"message"`
	Data    interface{} `json:"data,omitempty"`
}

// Error renders the error as "RPC error <code>: <message>"; Data is omitted.
func (e *RPCError) Error() string {
	return fmt.Sprintf("RPC error %d: %s", e.Code, e.Message)
}

// NewRPCError returns an *RPCError with the given code and message and no Data.
func NewRPCError(code int, message string) *RPCError {
	e := new(RPCError)
	e.Code, e.Message = code, message
	return e
}

// ACP-specific error codes that extend the standard JSON-RPC 2.0 set.
const (
	CodeRequestCancelled = -32800 // the client or server cancelled the request
	CodeRequestTimeout   = -32801 // the request ran out of time
)

// CoerceError converts a Go error to the appropriate JSON-RPC error type.
// It maps standard Go errors to their correct JSON-RPC error codes:
//
//   - nil                      → nil
//   - *RPCError, or an error wrapping one → that *RPCError, original code kept
//   - context.Canceled          → -32800 (RequestCancelled)
//   - context.DeadlineExceeded → -32801 (RequestTimeout)
//   - other errors             → -32603 (InternalError)
//
// Wrapped RPCErrors are found with errors.As. A handler that returns
// fmt.Errorf("...: %w", rpcErr) therefore keeps its code instead of being
// downgraded to an internal error. Newly created RPCErrors use err.Error() as
// their message.
func CoerceError(err error) error {
	if err == nil {
		return nil
	}
	var rpcErr *RPCError
	if errors.As(err, &rpcErr) {
		return rpcErr
	}
	if errors.Is(err, context.Canceled) {
		return &RPCError{Code: CodeRequestCancelled, Message: err.Error()}
	}
	if errors.Is(err, context.DeadlineExceeded) {
		return &RPCError{Code: CodeRequestTimeout, Message: err.Error()}
	}
	return &RPCError{Code: CodeInternalError, Message: err.Error()}
}
