// Package schemagen provides utilities for parsing JSON Schema and generating Go code.
package schemagen

import (
	"encoding/json"
	"fmt"
	"os"
	"strings"
	"unicode"
)

// Schema represents a parsed JSON Schema document: every entry of the
// document's top-level "$defs" object, parsed into a Def and keyed by its
// name under "$defs". Load builds it.
type Schema struct {
	Defs map[string]*Def
}

// Load reads the JSON Schema document at path, parses every entry of its
// top-level "$defs" object into a Def, and then resolves and classifies each
// definition.
//
// A document without "$defs" yields a Schema whose Defs map is empty (but
// non-nil). Failures to read the file, to decode the document or its "$defs"
// object, or to parse any single definition are returned wrapped with context.
func Load(path string) (*Schema, error) {
	data, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("read file: %w", err)
	}
	var raw map[string]json.RawMessage
	if err := json.Unmarshal(data, &raw); err != nil {
		return nil, fmt.Errorf("parse JSON: %w", err)
	}

	s := &Schema{Defs: make(map[string]*Def)}

	// Parse $defs
	if defsRaw, ok := raw["$defs"]; ok {
		var defs map[string]json.RawMessage
		if err := json.Unmarshal(defsRaw, &defs); err != nil {
			return nil, fmt.Errorf("parse $defs: %w", err)
		}
		for name, defRaw := range defs {
			def, err := parseDef(name, defRaw)
			if err != nil {
				return nil, fmt.Errorf("parse def %q: %w", name, err)
			}
			s.Defs[name] = def
		}
	}

	// Union classification looks referenced definitions up in s.Defs, so this
	// pass must run only after every definition has been parsed above.
	for _, def := range s.Defs {
		s.resolveDef(def)
		s.classifyDef(def)
	}

	return s, nil
}

// Def represents one parsed entry of the schema's top-level "$defs" object.
// parseDef fills the keyword-derived fields; Kind (and, for unions, Variants)
// are set afterwards by Schema.classifyDef.
type Def struct {
	// Name is the key under "$defs"; Description is the "description"
	// keyword (empty when absent or not a string).
	Name        string
	Description string
	Kind        DefKind

	// For object types. Properties is nil when the definition has no
	// "properties" keyword; Required is the raw "required" list.
	Properties map[string]*Property
	Required   []string
	Extends    *Ref // base type: the last $ref listed in "allOf"

	// For enum types
	EnumValues []string // string values of "enum": [...]
	OneOfConst []string // string "const" values gathered from "oneOf"/"anyOf" branches (for oneOf, also consts found inside branch properties)

	// For anyOf enum-like types whose consts are not strings (e.g. ErrorCode's
	// integer codes): the branch "title" is collected instead.
	AnyOfEnumVals []string // branch titles; presumably consumed when OneOfConst is empty — confirm in the generator

	// For discriminated unions
	OneOf         []*Ref    // refs found in "allOf" inside each "oneOf" branch
	Discriminator string    // discriminator field name (e.g. "type", "sessionUpdate"); parseAnyOf falls back to "type"
	Variants      []Variant // filled by classifyDef for union kinds

	// For simple types: the "type" keyword when it is a single string (a type
	// array such as ["string","null"] leaves it empty).
	Type string

	// Refs found in "allOf" inside each "anyOf" branch (e.g. McpServer, AuthMethod)
	AnyOf []*Ref

	// True when "additionalProperties" is the literal true (map[string]any type alias)
	AdditionalPropertiesAny bool
}

// Variant represents one branch of a discriminated union, as recorded in
// Def.Variants when classifyDef classifies a union.
type Variant struct {
	Ref        *Ref   // definition holding the branch's payload type
	FieldName  string // Go field name (e.g. "Text", "Http")
	DiscrimVal string // discriminator value (e.g. "text", "http"); empty when none was found
}

// Ref represents a $ref to another definition, reduced to the referenced
// definition's name by parseRef.
type Ref struct {
	Name    string          // e.g. "TextContent" from "#/$defs/TextContent"
	RawJSON json.RawMessage // raw JSON of the oneOf/anyOf branch the ref was found in, when the parser recorded it (used to read e.g. the branch "title"); empty otherwise
}

// Property represents one entry of an object definition's "properties".
// Name is the property key, Description its "description" keyword (empty when
// absent or not a string), and Type the type derived from its schema by
// parseTypeRef. Required is meant to mirror whether Name is listed in the
// owning Def's Required slice (NOTE(review): confirm the parser populates it).
type Property struct {
	Name        string
	Description string
	Type        *TypeRef
	Required    bool
}

// TypeRef describes the type of a property, array element or map value.
// GoType holds either a JSON Schema primitive name exactly as written in the
// schema (e.g. "string", "integer", "object") or the name of a referenced
// definition. ArrayItem is the element type when IsSlice is set; MapKey
// (always "string") and MapValue describe the map when IsMap is set.
type TypeRef struct {
	GoType    string // e.g. "string", "integer", "TextContent"
	IsSlice   bool   // []T; set when the schema has "items"
	IsMap     bool   // map[string]T; set from "additionalProperties"
	IsPtr     bool   // set for $ref types and for nullable types
	IsAny     bool   // any (interface{})
	ArrayItem *TypeRef
	MapKey    *TypeRef
	MapValue  *TypeRef
}

// DefKind classifies a Def by the shape of type it maps to; it is assigned by
// Schema.classifyDef.
type DefKind int

// Definition kinds. The zero value, DefKindUnknown, means "not classified".
const (
	DefKindUnknown DefKind = iota
	DefKindObject          // object with properties (also the fallback kind)
	DefKindEnum            // string enum (oneOf/anyOf consts or an "enum" array)
	DefKindUnion           // discriminated union over oneOf/anyOf branches
	DefKindSimple          // simple type (string, integer, etc.)
	DefKindRef             // alias of another definition
)

// String returns the lowercase name of k: "object", "enum", "union",
// "simple" or "ref". DefKindUnknown, and any value outside the declared
// range, yields "unknown".
func (k DefKind) String() string {
	names := [...]string{
		DefKindUnknown: "unknown",
		DefKindObject:  "object",
		DefKindEnum:    "enum",
		DefKindUnion:   "union",
		DefKindSimple:  "simple",
		DefKindRef:     "ref",
	}
	if k < 0 || int(k) >= len(names) {
		return "unknown"
	}
	return names[k]
}

// parseDef parses one entry of the schema's "$defs" object into a Def.
//
// name is the entry's key and raw its JSON. The keywords read are
// description, type, properties, required, enum, allOf, oneOf, anyOf and
// additionalProperties. Scalar keywords with an unexpected JSON shape are
// left at their zero value. Property schemas that are not JSON objects are
// skipped. Properties listed in "required" get Property.Required set.
//
// An error is returned when raw is not a JSON object, or when allOf, oneOf or
// anyOf is present but is not a JSON array.
func parseDef(name string, raw json.RawMessage) (*Def, error) {
	var m map[string]json.RawMessage
	if err := json.Unmarshal(raw, &m); err != nil {
		return nil, err
	}

	d := &Def{Name: name}

	// Scalar keywords are best-effort. In particular, a type array such as
	// ["string","null"] does not decode into a string, so d.Type stays "".
	if desc, ok := m["description"]; ok {
		_ = json.Unmarshal(desc, &d.Description)
	}
	if t, ok := m["type"]; ok {
		_ = json.Unmarshal(t, &d.Type)
	}
	if req, ok := m["required"]; ok {
		_ = json.Unmarshal(req, &d.Required)
	}
	if enum, ok := m["enum"]; ok {
		_ = json.Unmarshal(enum, &d.EnumValues)
	}

	// Properties; entries whose schema is not a JSON object are skipped.
	if props, ok := m["properties"]; ok {
		d.Properties = make(map[string]*Property)
		var rawProps map[string]json.RawMessage
		if err := json.Unmarshal(props, &rawProps); err == nil {
			for propName, propRaw := range rawProps {
				prop, err := parseProperty(propName, propRaw)
				if err != nil {
					continue
				}
				d.Properties[propName] = prop
			}
		}
	}

	// Mirror "required" onto the parsed properties so Property.Required agrees
	// with d.Required. Names without a matching property are ignored; a nil
	// Properties map simply yields no matches.
	for _, req := range d.Required {
		if p := d.Properties[req]; p != nil {
			p.Required = true
		}
	}

	// Combinators must be arrays; anything else is a malformed schema and is
	// reported instead of being silently dropped. oneOf runs before anyOf
	// because both may set d.Discriminator.
	if allOf, ok := m["allOf"]; ok {
		if err := d.parseAllOf(allOf); err != nil {
			return nil, fmt.Errorf("allOf: %w", err)
		}
	}
	if oneOf, ok := m["oneOf"]; ok {
		if err := d.parseOneOf(oneOf); err != nil {
			return nil, fmt.Errorf("oneOf: %w", err)
		}
	}
	if anyOf, ok := m["anyOf"]; ok {
		if err := d.parseAnyOf(anyOf); err != nil {
			return nil, fmt.Errorf("anyOf: %w", err)
		}
	}

	// additionalProperties: true (the literal, not a sub-schema) marks the
	// definition as map[string]any (e.g. Meta).
	if addProps, ok := m["additionalProperties"]; ok {
		var allowAny bool
		if json.Unmarshal(addProps, &allowAny) == nil && allowAny {
			d.AdditionalPropertiesAny = true
		}
	}

	return d, nil
}

// parseAllOf scans an "allOf" array for {"$ref": "..."} entries and stores
// the reference in d.Extends. When several entries carry a $ref, the last one
// wins. Entries that are not objects, or whose $ref is not a string, are
// ignored. An error is returned only when raw is not a JSON array.
func (d *Def) parseAllOf(raw json.RawMessage) error {
	var branches []json.RawMessage
	if err := json.Unmarshal(raw, &branches); err != nil {
		return err
	}
	for _, branch := range branches {
		var fields map[string]json.RawMessage
		if json.Unmarshal(branch, &fields) != nil {
			continue
		}
		refJSON, hasRef := fields["$ref"]
		if !hasRef {
			continue
		}
		var target string
		if json.Unmarshal(refJSON, &target) != nil {
			continue
		}
		d.Extends = parseRef(target)
	}
	return nil
}

// parseOneOf records the branches of a "oneOf" combinator on d.
//
// For each branch that is a JSON object:
//   - A root-level string "const" (a plain enum member such as
//     {"const": "pending"}) is appended to d.OneOfConst.
//   - A string "const" inside "properties" is the branch's discriminator tag.
//     The tag is appended to d.OneOfConst, and the property name becomes
//     d.Discriminator if none is set yet. Once a discriminator is known, only
//     that property is consulted, so each branch contributes at most one tag.
//     While it is unknown, the lexicographically smallest const-carrying
//     property is chosen, so the result does not depend on map iteration order.
//   - Every "$ref" inside the branch's "allOf" is appended to d.OneOf with
//     RawJSON set to the whole branch, so the tag can be read back per ref.
//
// An error is returned only when raw is not a JSON array; branches of any
// other shape are skipped.
func (d *Def) parseOneOf(raw json.RawMessage) error {
	var items []json.RawMessage
	if err := json.Unmarshal(raw, &items); err != nil {
		return err
	}

	// stringConst reports the value of schema["const"] when it is a JSON string.
	stringConst := func(schema map[string]json.RawMessage) (string, bool) {
		c, ok := schema["const"]
		if !ok {
			return "", false
		}
		var v string
		if json.Unmarshal(c, &v) != nil {
			return "", false
		}
		return v, true
	}

	for _, item := range items {
		var m map[string]json.RawMessage
		if err := json.Unmarshal(item, &m); err != nil {
			continue
		}

		// Root-level const (simple enum variants like { "const": "pending" })
		if v, ok := stringConst(m); ok {
			d.OneOfConst = append(d.OneOfConst, v)
		}

		// Const in properties (discriminator-based union): at most one tag per branch.
		if props, ok := m["properties"]; ok {
			var rawProps map[string]json.RawMessage
			if json.Unmarshal(props, &rawProps) == nil {
				tagName, tagVal, found := "", "", false
				for propName, propRaw := range rawProps {
					if d.Discriminator != "" && propName != d.Discriminator {
						continue
					}
					var prop map[string]json.RawMessage
					if json.Unmarshal(propRaw, &prop) != nil {
						continue
					}
					v, ok := stringConst(prop)
					if !ok {
						continue
					}
					if !found || propName < tagName {
						tagName, tagVal, found = propName, v, true
					}
				}
				if found {
					d.OneOfConst = append(d.OneOfConst, tagVal)
					if d.Discriminator == "" {
						d.Discriminator = tagName
					}
				}
			}
		}

		// Extract $ref from allOf, keeping the branch JSON on the ref.
		if allOf, ok := m["allOf"]; ok {
			var allOfItems []json.RawMessage
			if json.Unmarshal(allOf, &allOfItems) == nil {
				for _, ao := range allOfItems {
					var aoMap map[string]json.RawMessage
					if json.Unmarshal(ao, &aoMap) != nil {
						continue
					}
					refRaw, ok := aoMap["$ref"]
					if !ok {
						continue
					}
					var refStr string
					if json.Unmarshal(refRaw, &refStr) != nil {
						continue
					}
					ref := parseRef(refStr)
					ref.RawJSON = item
					d.OneOf = append(d.OneOf, ref)
				}
			}
		}
	}
	return nil
}

// parseAnyOf records the branches of an "anyOf" combinator on d.
//
// For each branch that is a JSON object:
//   - A root-level string "const" is appended to d.OneOfConst.
//   - A string "title" is appended to d.AnyOfEnumVals when this same branch
//     has no non-empty string const (e.g. ErrorCode's integer codes).
//   - A string "const" inside "properties" names the discriminator field. The
//     first branch that has one decides; within it, the lexicographically
//     smallest property name is chosen.
//   - Every "$ref" inside the branch's "allOf" is appended to d.AnyOf with
//     RawJSON set to the whole branch (used later to read its "title").
//
// If d.Discriminator is still empty afterwards, it is set to the explicit
// property-derived name. Failing that, it is set to "type" when any root
// const or title was seen. An explicit name therefore always beats the default.
//
// An error is returned only when raw is not a JSON array.
func (d *Def) parseAnyOf(raw json.RawMessage) error {
	var items []json.RawMessage
	if err := json.Unmarshal(raw, &items); err != nil {
		return err
	}

	explicitDisc := ""  // discriminator named by a const inside "properties"
	wantDefault := false // a root const or title implies the "type" default

	for _, item := range items {
		var m map[string]json.RawMessage
		if err := json.Unmarshal(item, &m); err != nil {
			continue
		}

		// Root-level const (enum values or discriminator)
		hasStringConst := false
		if c, ok := m["const"]; ok {
			var constVal string
			if err := json.Unmarshal(c, &constVal); err == nil {
				d.OneOfConst = append(d.OneOfConst, constVal)
				// An empty string (or null) const is not a usable enum value,
				// so the title may stand in for it.
				hasStringConst = constVal != ""
				wantDefault = true
			}
		}

		// title: the enum value when this branch's const is not a usable string.
		if title, ok := m["title"]; ok {
			var titleStr string
			if err := json.Unmarshal(title, &titleStr); err == nil {
				if !hasStringConst {
					d.AnyOfEnumVals = append(d.AnyOfEnumVals, titleStr)
				}
				wantDefault = true
			}
		}

		// Const discriminator in properties; iterate deterministically by
		// keeping the smallest matching property name.
		if explicitDisc == "" {
			if props, ok := m["properties"]; ok {
				var rawProps map[string]json.RawMessage
				if json.Unmarshal(props, &rawProps) == nil {
					for propName, propRaw := range rawProps {
						if explicitDisc != "" && propName >= explicitDisc {
							continue
						}
						var prop map[string]json.RawMessage
						if json.Unmarshal(propRaw, &prop) != nil {
							continue
						}
						c, ok := prop["const"]
						if !ok {
							continue
						}
						var constVal string
						if json.Unmarshal(c, &constVal) == nil {
							explicitDisc = propName
						}
					}
				}
			}
		}

		// Extract $ref from allOf inside anyOf item
		if allOf, ok := m["allOf"]; ok {
			var allOfItems []json.RawMessage
			if json.Unmarshal(allOf, &allOfItems) == nil {
				for _, ao := range allOfItems {
					var aoMap map[string]json.RawMessage
					if json.Unmarshal(ao, &aoMap) != nil {
						continue
					}
					refRaw, ok := aoMap["$ref"]
					if !ok {
						continue
					}
					var refStr string
					if json.Unmarshal(refRaw, &refStr) != nil {
						continue
					}
					ref := parseRef(refStr)
					ref.RawJSON = item // carry raw JSON for title extraction
					d.AnyOf = append(d.AnyOf, ref)
				}
			}
		}
	}

	// Keep a discriminator set earlier (e.g. by parseOneOf).
	if d.Discriminator == "" {
		switch {
		case explicitDisc != "":
			d.Discriminator = explicitDisc
		case wantDefault:
			d.Discriminator = "type"
		}
	}
	return nil
}

// parseProperty builds the Property called name from its JSON Schema raw.
// It fails only when raw is not a JSON object. A "description" that is not a
// string is ignored. The type comes from parseTypeRef; Required is left false.
func parseProperty(name string, raw json.RawMessage) (*Property, error) {
	var schema map[string]json.RawMessage
	if err := json.Unmarshal(raw, &schema); err != nil {
		return nil, err
	}

	var description string
	if rawDesc, found := schema["description"]; found {
		_ = json.Unmarshal(rawDesc, &description) // best effort
	}

	return &Property{
		Name:        name,
		Description: description,
		Type:        parseTypeRef(schema),
	}, nil
}

// parseTypeRef derives the TypeRef for a property, array-element or map-value
// schema, given as its top-level keywords m.
//
// Keywords are consulted in this order:
//  1. additionalProperties: the literal true, or a sub-schema, gives the type
//     a map[string]T shape. T is any for true or for an unconstrained
//     sub-schema. The map shape is kept even when a later keyword sets GoType.
//  2. $ref: the referenced definition, as a pointer (returns).
//  3. allOf: the first $ref inside, as a pointer (returns if found).
//  4. anyOf (returns):
//     - a $ref branch wins outright;
//     - otherwise the only non-null branch is parsed recursively;
//     - several differing non-null branches give IsAny;
//     - a "null" branch makes the result IsPtr.
//  5. type: a single name, or an array such as ["string","null"]. "null"
//     sets IsPtr; several differing non-null names give IsAny.
//  6. items: marks the type as a slice and parses the element schema.
//
// Primitive names are stored in GoType as written in the schema (e.g. "integer").
func parseTypeRef(m map[string]json.RawMessage) *TypeRef {
	tr := &TypeRef{}

	// Handle additionalProperties BEFORE type/anyOf (takes priority)
	// e.g. _meta: { additionalProperties: true, type: ["object","null"] }
	if addProps, ok := m["additionalProperties"]; ok {
		var allowed bool
		var addMap map[string]json.RawMessage
		if json.Unmarshal(addProps, &allowed) == nil {
			// Boolean form. true allows arbitrary values → map[string]any;
			// false only forbids extra keys and adds no map shape. (A bool
			// never decodes into a map, so it must be checked separately.)
			if allowed {
				tr.IsMap = true
				tr.MapKey = &TypeRef{GoType: "string"}
				tr.MapValue = &TypeRef{IsAny: true}
			}
		} else if json.Unmarshal(addProps, &addMap) == nil {
			if _, ok := addMap["type"]; ok {
				// { additionalProperties: { type: ... } }
				tr.IsMap = true
				tr.MapKey = &TypeRef{GoType: "string"}
				tr.MapValue = parseTypeRef(addMap)
			} else if ref, ok := addMap["$ref"]; ok {
				var refStr string
				if json.Unmarshal(ref, &refStr) == nil {
					tr.IsMap = true
					tr.MapKey = &TypeRef{GoType: "string"}
					tr.MapValue = &TypeRef{GoType: parseRef(refStr).Name, IsPtr: true}
				}
			} else {
				// Unconstrained sub-schema (e.g. {}) → map[string]any.
				tr.IsMap = true
				tr.MapKey = &TypeRef{GoType: "string"}
				tr.MapValue = &TypeRef{IsAny: true}
			}
		}
	}

	// $ref (must be checked before type/anyOf since defs may have both)
	if ref, ok := m["$ref"]; ok {
		var refStr string
		if json.Unmarshal(ref, &refStr) == nil {
			tr.GoType = parseRef(refStr).Name
			tr.IsPtr = true
		}
		return tr
	}

	// allOf with $ref (e.g. sessionId: { allOf: [{ $ref: "SessionId" }] })
	if allOf, ok := m["allOf"]; ok {
		var allOfItems []json.RawMessage
		if json.Unmarshal(allOf, &allOfItems) == nil {
			for _, ao := range allOfItems {
				var aoMap map[string]json.RawMessage
				if json.Unmarshal(ao, &aoMap) != nil {
					continue
				}
				ref, ok := aoMap["$ref"]
				if !ok {
					continue
				}
				var refStr string
				if json.Unmarshal(ref, &refStr) == nil {
					tr.GoType = parseRef(refStr).Name
					tr.IsPtr = true
					return tr
				}
			}
		}
	}

	// anyOf (nullable or union type)
	if anyOf, ok := m["anyOf"]; ok {
		var items []json.RawMessage
		if json.Unmarshal(anyOf, &items) == nil {
			hasNull := false
			var alts []map[string]json.RawMessage
			for _, item := range items {
				var itemMap map[string]json.RawMessage
				if json.Unmarshal(item, &itemMap) != nil {
					continue
				}
				if ref, ok := itemMap["$ref"]; ok {
					var refStr string
					if json.Unmarshal(ref, &refStr) == nil {
						tr.GoType = parseRef(refStr).Name
						tr.IsPtr = true
						return tr
					}
				}
				var typeStr string
				if t, ok := itemMap["type"]; ok && json.Unmarshal(t, &typeStr) == nil && typeStr == "null" {
					hasNull = true
					continue
				}
				alts = append(alts, itemMap)
			}

			// Parse every non-null branch fully (arrays, maps, allOf refs,
			// nested nullability) instead of keeping only its "type" string.
			var inner *TypeRef
			ambiguous := false
			for _, alt := range alts {
				cand := parseTypeRef(alt)
				switch {
				case inner == nil:
					inner = cand
				case cand.GoType != inner.GoType ||
					cand.IsSlice || cand.IsMap || cand.IsAny ||
					inner.IsSlice || inner.IsMap || inner.IsAny:
					// Several distinct alternatives (e.g. integer|string) cannot
					// be represented by one concrete Go type.
					ambiguous = true
				}
			}

			switch {
			case ambiguous:
				tr.IsAny = true
			case inner != nil:
				tr.GoType = inner.GoType
				tr.IsAny = inner.IsAny
				if inner.IsSlice {
					tr.IsSlice = true
					tr.ArrayItem = inner.ArrayItem
				}
				if inner.IsMap {
					tr.IsMap = true
					tr.MapKey = inner.MapKey
					tr.MapValue = inner.MapValue
				}
				hasNull = hasNull || inner.IsPtr
			}
			tr.IsPtr = hasNull
		}
		return tr
	}

	// Simple type (including nullable: type: ["string", "null"])
	if t, ok := m["type"]; ok {
		var typeArr []any
		if json.Unmarshal(t, &typeArr) == nil {
			// Array of types like ["string", "null"]
			hasNull := false
			mixed := false
			var actualType string
			for _, v := range typeArr {
				s, ok := v.(string)
				if !ok {
					continue
				}
				if s == "null" {
					hasNull = true
					continue
				}
				if actualType != "" && s != actualType {
					mixed = true
				}
				actualType = s
			}
			if mixed {
				tr.IsAny = true
			} else {
				tr.GoType = actualType
			}
			tr.IsPtr = hasNull
		} else {
			// Single type string
			var typeStr string
			if json.Unmarshal(t, &typeStr) == nil {
				tr.GoType = typeStr
			}
		}
	}

	// items (array)
	if items, ok := m["items"]; ok {
		tr.IsSlice = true
		var itemMap map[string]json.RawMessage
		if json.Unmarshal(items, &itemMap) == nil {
			if ref, ok := itemMap["$ref"]; ok {
				var refStr string
				if json.Unmarshal(ref, &refStr) == nil {
					tr.ArrayItem = &TypeRef{GoType: parseRef(refStr).Name, IsPtr: true}
				}
			} else {
				tr.ArrayItem = parseTypeRef(itemMap)
			}
		}
	}

	return tr
}

// parseRef turns a "$ref" string such as "#/$defs/TextContent" into a Ref
// whose Name is the final "/"-separated segment ("TextContent"). The RFC 6901
// JSON Pointer escapes in that segment are decoded: "~1" becomes "/" and
// "~0" becomes "~", applied in that order as the RFC specifies. A ref without
// any "/" is used whole. Only the name is kept.
func parseRef(refStr string) *Ref {
	name := refStr[strings.LastIndex(refStr, "/")+1:]
	name = strings.ReplaceAll(name, "~1", "/")
	name = strings.ReplaceAll(name, "~0", "~")
	return &Ref{Name: name}
}

// resolveDef hands the type of every property of d to resolveTypeRef.
// Refs held in Extends, OneOf and AnyOf are not visited.
func (s *Schema) resolveDef(d *Def) {
	for key := range d.Properties {
		s.resolveTypeRef(d.Properties[key].Type)
	}
}

// resolveTypeRef walks the nested types of tr: the element type of a slice
// (ArrayItem) and the value type of a map (MapValue); MapKey is not visited.
// NOTE(review): no resolution logic exists yet — the walk modifies nothing.
func (s *Schema) resolveTypeRef(tr *TypeRef) {
	if tr == nil {
		return
	}
	// Either child may be nil; the guard above absorbs that on recursion.
	s.resolveTypeRef(tr.ArrayItem)
	s.resolveTypeRef(tr.MapValue)
}

// classifyDef sets d.Kind by trying the rules below in order; the first rule
// that matches wins. Union rules also fill d.Variants (extractVariants for
// oneOf, extractAnyOfVariants for anyOf). Anything unmatched becomes
// DefKindObject.
func (s *Schema) classifyDef(d *Def) {
	// 1. Discriminated union: more than one oneOf ref and a known discriminator.
	if len(d.OneOf) > 1 && d.Discriminator != "" {
		d.Kind = DefKindUnion
		s.extractVariants(d)
		return
	}

	// 2. anyOf refs together with string consts → enum.
	// NOTE(review): the discriminator is not checked here. parseAnyOf leaves
	// Discriminator non-empty whenever it records a root-level string const,
	// so an anyOf union whose branches also carry such consts lands here
	// rather than in the union rules below. Enums whose branches only have
	// non-string consts (titles collected in AnyOfEnumVals) leave OneOfConst
	// empty and do not match this rule.
	if len(d.AnyOf) > 0 && len(d.OneOfConst) > 0 {
		d.Kind = DefKindEnum
		return
	}

	// 3. Exactly one anyOf ref and no own "type" or properties: a union when a
	// discriminator is known, otherwise an alias of the referenced type.
	if len(d.AnyOf) == 1 && d.Type == "" && len(d.Properties) == 0 {
		if d.Discriminator != "" {
			// Has discriminator: emit as discriminated union
			d.Kind = DefKindUnion
			s.extractAnyOfVariants(d)
		} else {
			// No discriminator: resolve to the ref type
			d.Kind = DefKindRef
		}
		return
	}

	// 4. Several anyOf refs: a union when a discriminator is known, otherwise
	// DefKindSimple (any type).
	// NOTE(review): d.AnyOf only receives refs taken from
	// {"allOf":[{"$ref":...}]} branches by parseAnyOf, so an anyOf made of
	// plain primitive branches (e.g. RequestId: null|int|string) leaves
	// d.AnyOf empty and never reaches this rule.
	if len(d.AnyOf) > 1 {
		// parseRef never returns nil, so for refs appended by parseAnyOf this
		// check always finds one.
		hasRef := false
		for _, r := range d.AnyOf {
			if r != nil {
				hasRef = true
				break
			}
		}
		if hasRef && d.Discriminator != "" {
			d.Kind = DefKindUnion
			s.extractAnyOfVariants(d)
			return
		}
		// No refs or no discriminator: emit as DefKindSimple (any type)
		d.Kind = DefKindSimple
		return
	}

	// 5. Enum: string consts from oneOf/anyOf branches and no properties.
	if len(d.OneOfConst) > 0 && len(d.Properties) == 0 {
		d.Kind = DefKindEnum
		return
	}

	// 6. Direct "enum" array.
	if len(d.EnumValues) > 0 {
		d.Kind = DefKindEnum
		return
	}

	// 7. Object with properties or an allOf base type.
	if len(d.Properties) > 0 || d.Extends != nil {
		d.Kind = DefKindObject
		return
	}

	// 8. Simple type: a single "type" string other than "object".
	if d.Type != "" && d.Type != "object" {
		d.Kind = DefKindSimple
		return
	}

	// 9. additionalProperties: true (e.g. Meta → map[string]any)
	if d.AdditionalPropertiesAny {
		d.Kind = DefKindSimple
		return
	}

	// Fallback: treat as an (empty) object.
	d.Kind = DefKindObject
}

// extractVariants fills d.Variants with one entry per oneOf ref whose target
// definition exists in s.Defs.
//
// Each branch's discriminator value is read from its own JSON when available
// (ref.RawJSON → properties.<discriminator>.const). Otherwise it falls back
// to the positional entry d.OneOfConst[i]. That fallback is only correct when
// every earlier branch contributed exactly one const and one ref.
// FieldName is derived from the discriminator value via DiscrimToFieldName,
// or is the referenced definition's name when no value is known. The
// discriminator field defaults to "type" when d.Discriminator is empty.
func (s *Schema) extractVariants(d *Def) {
	discField := d.Discriminator
	if discField == "" {
		discField = "type"
	}

	for i, ref := range d.OneOf {
		if ref == nil {
			continue
		}
		sub := s.Defs[ref.Name]
		if sub == nil {
			continue
		}

		// Prefer the tag stored in the branch itself. Any missing level or
		// non-string const makes an Unmarshal fail and leaves discVal "".
		discVal := ""
		if len(ref.RawJSON) > 0 {
			var branch, props, prop map[string]json.RawMessage
			if json.Unmarshal(ref.RawJSON, &branch) == nil &&
				json.Unmarshal(branch["properties"], &props) == nil &&
				json.Unmarshal(props[discField], &prop) == nil {
				_ = json.Unmarshal(prop["const"], &discVal)
			}
		}
		if discVal == "" && i < len(d.OneOfConst) {
			discVal = d.OneOfConst[i]
		}

		// The variant field name comes from the discriminator value, or the sub-type name.
		fieldName := sub.Name
		if discVal != "" {
			fieldName = DiscrimToFieldName(discVal)
		}

		d.Variants = append(d.Variants, Variant{
			Ref:        ref,
			FieldName:  fieldName,
			DiscrimVal: discVal,
		})
	}
}

// extractAnyOfVariants appends one Variant to d.Variants for every non-nil
// ref in d.AnyOf.
//
// FieldName is the ref name with a trailing "McpServer" or "AuthMethod"
// stripped (HttpMcpServer -> Http), or the ref name unchanged otherwise.
// DiscrimVal comes from inferDiscrimVal. When that only produced its fallback
// (the lowercased ref name) and the branch JSON in ref.RawJSON has a
// non-empty string "title", the lowercased title is used instead.
func (s *Schema) extractAnyOfVariants(d *Def) {
	field := d.Discriminator
	if field == "" {
		field = "type"
	}

	for _, ref := range d.AnyOf {
		if ref == nil {
			continue
		}

		goField := ref.Name
		for _, suffix := range []string{"McpServer", "AuthMethod"} {
			if strings.HasSuffix(goField, suffix) {
				goField = strings.TrimSuffix(goField, suffix)
				break
			}
		}

		val := inferDiscrimVal(ref.Name, field)
		if val == strings.ToLower(ref.Name) {
			// inferDiscrimVal had no rule for this name; try the branch title.
			var title string
			var branch map[string]json.RawMessage
			if len(ref.RawJSON) > 0 && json.Unmarshal(ref.RawJSON, &branch) == nil {
				if rawTitle, ok := branch["title"]; ok {
					json.Unmarshal(rawTitle, &title)
				}
			}
			if title != "" {
				val = strings.ToLower(title)
			}
		}

		d.Variants = append(d.Variants, Variant{
			Ref:        ref,
			FieldName:  goField,
			DiscrimVal: val,
		})
	}
}

// inferDiscrimVal guesses a discriminator value from a definition name by
// checking the lowercased name against a fixed list of suffixes, first match
// wins. For example "McpServerHttp" -> "http" and "AuthMethodEnvVar" ->
// "env_var". When no suffix matches, the lowercased name itself is returned
// (not title-cased). discField is accepted but not consulted.
func inferDiscrimVal(refName, discField string) string {
	lower := strings.ToLower(refName)

	// Order matters: the first suffix that matches decides.
	rules := []struct{ suffix, value string }{
		{"http", "http"},
		{"sse", "sse"},
		{"stdio", "stdio"},
		{"agent", "agent"},
		{"authmethodenvvar", "env_var"},
		{"authmethodterminal", "terminal"},
	}
	for _, rule := range rules {
		if strings.HasSuffix(lower, rule.suffix) {
			return rule.value
		}
	}
	return lower
}

// DiscrimToFieldName converts a discriminator value to a Go field name. It
// splits the value on "_" and upper-cases the first rune of every non-empty
// part, keeping the remaining runes as they are.
// e.g. "user_message_chunk" -> "UserMessageChunk", "text" -> "Text",
// "resource_link" -> "ResourceLink".
//
// strings.Title is deliberately not used. It is deprecated, and it does not
// treat "_" as a word separator, so it would turn "resource_link" into
// "Resource_link".
func DiscrimToFieldName(discrim string) string {
	var b strings.Builder
	b.Grow(len(discrim))
	for _, word := range strings.Split(discrim, "_") {
		for i, r := range word {
			if i == 0 {
				// r is a fully decoded rune, so multi-byte first letters are
				// upper-cased correctly.
				r = unicode.ToUpper(r)
			}
			b.WriteRune(r)
		}
	}
	return b.String()
}

// ToGoType maps a JSON Schema primitive type name to Go type syntax:
// string → string, integer → int, number → float64, boolean → bool,
// array → []any and object → struct{} (a placeholder for the caller to
// replace). "null" maps to the empty string. Any other input, including
// definition names, maps to "any".
func ToGoType(schemaType string) string {
	goType := "any"
	switch schemaType {
	case "null":
		goType = ""
	case "boolean":
		goType = "bool"
	case "integer":
		goType = "int"
	case "number":
		goType = "float64"
	case "string":
		goType = "string"
	case "array":
		goType = "[]any"
	case "object":
		goType = "struct{}" // placeholder, will be resolved
	}
	return goType
}

// GoTypeName returns the full Go type name for a TypeRef.
//
// Resolution order:
//   - A nil TypeRef, or one marked IsAny, renders as "any".
//   - Slices render as "[]" + element type.
//   - Maps render as "map[string]" + value type.
//   - Otherwise GoType is used.
//
// A missing element or value type renders as "any". JSON Schema primitive
// names left in GoType by the parser are translated to Go:
//   - "integer" → "int"
//   - "number" → "float64"
//   - "boolean" → "bool"
//   - "object" → "map[string]any"
//   - an item-less "array" → "[]any"
//   - an empty or "null" GoType → "any"
//
// Any other value (normally a definition name, or "string") is returned
// unchanged. IsPtr is not reflected in the result.
func (tr *TypeRef) GoTypeName() string {
	if tr == nil || tr.IsAny {
		return "any"
	}

	if tr.IsSlice {
		itemType := "any"
		if tr.ArrayItem != nil {
			itemType = tr.ArrayItem.GoTypeName()
		}
		return "[]" + itemType
	}

	if tr.IsMap {
		valType := "any"
		if tr.MapValue != nil {
			valType = tr.MapValue.GoTypeName()
		}
		return "map[string]" + valType
	}

	switch tr.GoType {
	case "", "null":
		// No usable type information: any value is allowed.
		return "any"
	case "integer":
		return "int"
	case "number":
		return "float64"
	case "boolean":
		return "bool"
	case "object":
		return "map[string]any"
	case "array":
		return "[]any"
	default:
		return tr.GoType
	}
}
