package rpc

import (
	"bytes"
	"context"
	"encoding/json"
	"io"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestRPCError checks that NewRPCError keeps the numeric code it was given and
// that Error() renders the "RPC error <code>: <message>" form.
func TestRPCError(t *testing.T) {
	const (
		code = -32601
		msg  = "Method not found"
	)

	rpcErr := NewRPCError(code, msg)

	assert.Equal(t, code, rpcErr.Code)
	assert.Equal(t, "RPC error -32601: Method not found", rpcErr.Error())
}

// TestRouterLookup registers one request handler and verifies that Lookup
// resolves it by method name (flagged as a request) while reporting a miss for
// a method that was never registered.
func TestRouterLookup(t *testing.T) {
	const method = "test/method"

	router := NewRouter()
	router.HandleRequest(method, func(json.RawMessage) (any, error) {
		return map[string]string{"ok": "true"}, nil
	})

	route, found := router.Lookup(method)
	require.True(t, found)
	assert.Equal(t, method, route.Method)
	assert.True(t, route.IsRequest)

	_, found = router.Lookup("nonexistent")
	assert.False(t, found)
}

// TestRouterMakeHandler registers an "add" request handler and checks that the
// dispatch function produced by MakeHandler routes to it and hands back the
// handler's result unchanged.
func TestRouterMakeHandler(t *testing.T) {
	router := NewRouter()
	router.HandleRequest("add", func(params json.RawMessage) (interface{}, error) {
		var p struct{ A, B int }
		// Surface malformed params instead of silently summing zero values.
		if err := json.Unmarshal(params, &p); err != nil {
			return nil, err
		}
		return p.A + p.B, nil
	})

	handler := router.MakeHandler()
	result, err := handler("add", json.RawMessage(`{"A":2,"B":3}`), true)
	require.NoError(t, err)
	// The handler is invoked directly (no JSON round-trip), so the result is
	// the int the handler returned rather than a float64.
	assert.Equal(t, 5, result)
}

// TestRouterMethodNotFound checks that dispatching an unregistered method
// through MakeHandler yields an *RPCError with code -32601.
func TestRouterMethodNotFound(t *testing.T) {
	router := NewRouter()
	handler := router.MakeHandler()

	_, err := handler("nonexistent", nil, true)
	require.Error(t, err)

	// ErrorAs (rather than a direct type assertion) keeps the test valid if
	// the router ever wraps the *RPCError.
	var rpcErr *RPCError
	require.ErrorAs(t, err, &rpcErr)
	assert.Equal(t, -32601, rpcErr.Code)
}

// TestConnectionClose verifies that a freshly constructed Connection, with no
// receive loop started, can be closed without error.
func TestConnectionClose(t *testing.T) {
	pr, pw := io.Pipe()
	defer pr.Close()
	defer pw.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	noop := func(string, json.RawMessage, bool) (any, error) { return nil, nil }
	conn := NewConnection(ctx, noop, pr, pw)

	require.NoError(t, conn.Close())
}

// TestConnectionSendNotification sends one notification over a write-only
// Connection and checks that the written bytes decode to a JSON object whose
// "jsonrpc" is "2.0" and whose "method" is the name that was sent.
func TestConnectionSendNotification(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	out := new(bytes.Buffer)
	conn := NewConnection(ctx, func(string, json.RawMessage, bool) (any, error) {
		return nil, nil
	}, nil, out)
	defer conn.Close()

	require.NoError(t, conn.SendNotification("test/notif", map[string]string{"hello": "world"}))

	payload := bytes.TrimSpace(out.Bytes())
	decoded := map[string]any{}
	require.NoError(t, json.Unmarshal(payload, &decoded))

	assert.Equal(t, "2.0", decoded["jsonrpc"])
	assert.Equal(t, "test/notif", decoded["method"])
}

// TestConnectionRequestResponse wires a client and a server Connection
// together over two io.Pipes and checks that an "echo" request round-trips its
// "msg" parameter.
func TestConnectionRequestResponse(t *testing.T) {
	// Cross-connected pipes: client writes reach the server, server writes reach the client.
	serverReadEnd, clientWriteEnd := io.Pipe()
	clientReadEnd, serverWriteEnd := io.Pipe()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Server handles "echo"; anything else is reported as method-not-found.
	serverConn := NewConnection(ctx, func(method string, params json.RawMessage, isRequest bool) (interface{}, error) {
		if method == "echo" {
			var p struct {
				Msg string `json:"msg"`
			}
			// Report bad params to the client rather than echoing an empty string.
			if err := json.Unmarshal(params, &p); err != nil {
				return nil, err
			}
			return map[string]string{"echo": p.Msg}, nil
		}
		return nil, NewRPCError(-32601, "Method not found")
	}, serverReadEnd, serverWriteEnd)
	defer serverConn.Close()
	go serverConn.ReceiveLoop()

	// Client side: its own handler is a no-op; it only issues requests.
	clientConn := NewConnection(ctx, func(method string, params json.RawMessage, isRequest bool) (interface{}, error) {
		return nil, nil
	}, clientReadEnd, clientWriteEnd)
	defer clientConn.Close()
	go clientConn.ReceiveLoop()

	// Close the pipes on every exit path, including a failed require, so reads
	// in the ReceiveLoop goroutines return an error instead of blocking.
	// Registered last so it runs before the deferred Close calls.
	defer func() {
		serverReadEnd.Close()
		serverWriteEnd.Close()
		clientReadEnd.Close()
		clientWriteEnd.Close()
	}()

	result, err := SendRequest[map[string]any](ctx, clientConn, "echo", map[string]any{"msg": "hello"})
	require.NoError(t, err)
	require.NotNil(t, result)
	assert.Equal(t, "hello", result["echo"])
}

// TestConnectionSendRequestGeneric checks that SendRequest decodes a server
// reply into the requested generic type (here map[string]any, where JSON
// numbers arrive as float64).
func TestConnectionSendRequestGeneric(t *testing.T) {
	serverReadEnd, clientWriteEnd := io.Pipe()
	clientReadEnd, serverWriteEnd := io.Pipe()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	serverConn := NewConnection(ctx, func(method string, params json.RawMessage, isRequest bool) (interface{}, error) {
		if method == "getUser" {
			return map[string]any{"id": float64(42), "name": "Alice"}, nil
		}
		return nil, NewRPCError(-32601, "Method not found")
	}, serverReadEnd, serverWriteEnd)
	defer serverConn.Close()
	go serverConn.ReceiveLoop()

	clientConn := NewConnection(ctx, func(method string, params json.RawMessage, isRequest bool) (interface{}, error) {
		return nil, nil
	}, clientReadEnd, clientWriteEnd)
	defer clientConn.Close()
	go clientConn.ReceiveLoop()

	// Close the pipes on every exit path, including a failed require, so reads
	// in the ReceiveLoop goroutines return an error instead of blocking.
	// Registered last so it runs before the deferred Close calls.
	defer func() {
		serverReadEnd.Close()
		serverWriteEnd.Close()
		clientReadEnd.Close()
		clientWriteEnd.Close()
	}()

	result, err := SendRequest[map[string]any](ctx, clientConn, "getUser", nil)
	require.NoError(t, err)
	assert.Equal(t, float64(42), result["id"])
	assert.Equal(t, "Alice", result["name"])
}

// TestConnectionNotificationBounded checks that SendNotification does not
// block the caller, even after defaultNotifCapacity+100 notifications have been
// sent on a Connection whose ReceiveLoop was never started.
//
// The sends run in their own goroutine while the test goroutine waits on a
// timer. Polling a deadline with select/default between sends cannot detect a
// call that blocks forever (the loop never gets back to the select), which
// would hang until the go test timeout instead of failing.
func TestConnectionNotificationBounded(t *testing.T) {
	var buf bytes.Buffer

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	conn := NewConnection(ctx, func(method string, params json.RawMessage, isRequest bool) (interface{}, error) {
		return nil, nil
	}, nil, &buf)
	defer conn.Close()

	done := make(chan struct{})
	go func() {
		defer close(done)
		for i := 0; i < defaultNotifCapacity+100; i++ {
			// assert (not require): FailNow must only be called from the
			// test goroutine.
			assert.NoError(t, conn.SendNotification("ping", map[string]int{"i": i}))
		}
	}()

	// Same overall 500ms budget as before, but now enforced even if a single
	// call never returns.
	select {
	case <-done:
	case <-time.After(500 * time.Millisecond):
		t.Fatal("SendNotification blocked — bounded channel not working")
	}
}

// TestConnectionCancelRequest checks that SendRequest returns a context
// deadline error when the caller's context expires before the server replies.
func TestConnectionCancelRequest(t *testing.T) {
	serverReadEnd, clientWriteEnd := io.Pipe()
	clientReadEnd, serverWriteEnd := io.Pipe()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// The "slow" method blocks until the test-wide ctx is done, simulating a
	// long-running request that never answers within the client's timeout.
	serverConn := NewConnection(ctx, func(method string, params json.RawMessage, isRequest bool) (interface{}, error) {
		if method == "slow" {
			<-ctx.Done()
			return nil, ctx.Err()
		}
		return nil, nil
	}, serverReadEnd, serverWriteEnd)
	defer serverConn.Close()
	go serverConn.ReceiveLoop()

	clientConn := NewConnection(ctx, func(method string, params json.RawMessage, isRequest bool) (interface{}, error) {
		return nil, nil
	}, clientReadEnd, clientWriteEnd)
	defer clientConn.Close()
	go clientConn.ReceiveLoop()

	// Close the pipes on every exit path, including a failed require, so reads
	// in the ReceiveLoop goroutines return an error instead of blocking.
	// Registered after the connection defers so it runs before them.
	defer func() {
		serverReadEnd.Close()
		serverWriteEnd.Close()
		clientReadEnd.Close()
		clientWriteEnd.Close()
	}()

	// The client gives up after 50ms, long before the server would respond.
	reqCtx, reqCancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer reqCancel()

	_, err := SendRequest[any](reqCtx, clientConn, "slow", nil)
	// require (not assert): err.Error() below would panic on a nil error.
	require.Error(t, err)
	assert.Contains(t, err.Error(), "context deadline exceeded")
}

// TestCoerceError checks how CoerceError maps arbitrary errors onto *RPCError:
// nil stays nil, context cancellation/deadline map to their dedicated codes,
// an existing *RPCError is passed through unchanged, and anything else becomes
// CodeInternalError.
func TestCoerceError(t *testing.T) {
	tests := []struct {
		name         string
		input        error
		wantCode     int
		wantPassThru bool // result must equal the input *RPCError as-is
	}{
		{
			name:         "nil",
			input:        nil,
			wantCode:     0,
			wantPassThru: false,
		},
		{
			name:         "context_canceled",
			input:        context.Canceled,
			wantCode:     CodeRequestCancelled,
			wantPassThru: false,
		},
		{
			name:         "context_deadline_exceeded",
			input:        context.DeadlineExceeded,
			wantCode:     CodeRequestTimeout,
			wantPassThru: false,
		},
		{
			name:         "rpc_error_passthru",
			input:        NewRPCError(-32601, "Method not found"),
			wantCode:     -32601,
			wantPassThru: true,
		},
		{
			name:         "arbitrary_error",
			input:        assert.AnError,
			wantCode:     CodeInternalError,
			wantPassThru: false,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			result := CoerceError(tc.input)
			if tc.input == nil {
				assert.Nil(t, result)
				return
			}
			rpcErr, ok := result.(*RPCError)
			require.True(t, ok, "CoerceError(%v) returned %T, want *RPCError", tc.input, result)
			assert.Equal(t, tc.wantCode, rpcErr.Code)
			if tc.wantPassThru {
				// An *RPCError input must come back unchanged (same code AND
				// message), not rebuilt or re-wrapped.
				assert.Equal(t, tc.input, result)
			}
		})
	}
}

// TestConnectionReturnsCorrectErrorCode is an integration test: a server
// handler returns raw context errors, and the client must receive them as
// *RPCError values carrying CodeRequestCancelled / CodeRequestTimeout, i.e. the
// errors are coerced before being sent over the wire.
func TestConnectionReturnsCorrectErrorCode(t *testing.T) {
	serverReadEnd, clientWriteEnd := io.Pipe()
	clientReadEnd, serverWriteEnd := io.Pipe()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	serverConn := NewConnection(ctx, func(method string, params json.RawMessage, isRequest bool) (interface{}, error) {
		if method == "cancel_me" {
			return nil, context.Canceled
		}
		if method == "timeout_me" {
			return nil, context.DeadlineExceeded
		}
		return nil, NewRPCError(-32601, "unknown")
	}, serverReadEnd, serverWriteEnd)
	defer serverConn.Close()
	go serverConn.ReceiveLoop()

	clientConn := NewConnection(ctx, func(method string, params json.RawMessage, isRequest bool) (interface{}, error) {
		return nil, nil
	}, clientReadEnd, clientWriteEnd)
	defer clientConn.Close()
	go clientConn.ReceiveLoop()

	// Close the pipes when the test ends (these were previously never closed),
	// so reads in the ReceiveLoop goroutines return an error instead of
	// blocking. Registered last so it runs before the deferred Close calls,
	// matching the sibling tests' teardown order.
	defer func() {
		serverReadEnd.Close()
		serverWriteEnd.Close()
		clientReadEnd.Close()
		clientWriteEnd.Close()
	}()

	t.Run("canceled", func(t *testing.T) {
		reqCtx, reqCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
		defer reqCancel()
		_, err := SendRequest[any](reqCtx, clientConn, "cancel_me", nil)
		require.Error(t, err)
		var rpcErr *RPCError
		require.ErrorAs(t, err, &rpcErr)
		assert.Equal(t, CodeRequestCancelled, rpcErr.Code)
	})

	t.Run("deadline_exceeded", func(t *testing.T) {
		reqCtx, reqCancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
		defer reqCancel()
		_, err := SendRequest[any](reqCtx, clientConn, "timeout_me", nil)
		require.Error(t, err)
		var rpcErr *RPCError
		require.ErrorAs(t, err, &rpcErr)
		assert.Equal(t, CodeRequestTimeout, rpcErr.Code)
	})
}
