diff --git a/pkg/config/latest/parse.go b/pkg/config/latest/parse.go index 4b3af7edb..814fc29c8 100644 --- a/pkg/config/latest/parse.go +++ b/pkg/config/latest/parse.go @@ -4,7 +4,7 @@ import ( "github.com/goccy/go-yaml" "github.com/docker/docker-agent/pkg/config/types" - previous "github.com/docker/docker-agent/pkg/config/v6" + previous "github.com/docker/docker-agent/pkg/config/v7" ) func Register(parsers map[string]func([]byte) (any, error), upgraders *[]func(any, []byte) (any, error)) { @@ -26,34 +26,5 @@ func upgradeIfNeeded(c any, _ []byte) (any, error) { var config Config types.CloneThroughJSON(old, &config) - - // Migrate AgentConfig.RAG []string → toolsets with type: rag + ref - for i, agent := range old.Agents { - if len(agent.RAG) == 0 { - continue - } - for _, ragName := range agent.RAG { - config.Agents[i].Toolsets = append(config.Agents[i].Toolsets, Toolset{ - Type: "rag", - Ref: ragName, - }) - } - } - - // Migrate top-level RAG map from RAGConfig to RAGToolset - if len(old.RAG) > 0 && config.RAG == nil { - config.RAG = make(map[string]RAGToolset) - } - for name, oldRAG := range old.RAG { - var ragCfg RAGConfig - types.CloneThroughJSON(oldRAG, &ragCfg) - config.RAG[name] = RAGToolset{ - Toolset: Toolset{ - Type: "rag", - RAGConfig: &ragCfg, - }, - } - } - return config, nil } diff --git a/pkg/config/latest/types.go b/pkg/config/latest/types.go index 2396939a1..54929d5f8 100644 --- a/pkg/config/latest/types.go +++ b/pkg/config/latest/types.go @@ -16,7 +16,7 @@ import ( "github.com/docker/docker-agent/pkg/effort" ) -const Version = "7" +const Version = "8" // Config represents the entire configuration file type Config struct { diff --git a/pkg/config/v7/model_config_clone_test.go b/pkg/config/v7/model_config_clone_test.go new file mode 100644 index 000000000..6e40a7b53 --- /dev/null +++ b/pkg/config/v7/model_config_clone_test.go @@ -0,0 +1,79 @@ +package latest + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestModelConfig_Clone_DeepCopiesPointerFields(t *testing.T) { + t.Parallel() + + temp := 0.7 + maxTokens := int64(4096) + topP := 0.9 + parallel := true + trackUsage := true + + original := &ModelConfig{ + Provider: "openai", + Model: "gpt-4o", + Temperature: &temp, + MaxTokens: &maxTokens, + TopP: &topP, + ParallelToolCalls: ¶llel, + TrackUsage: &trackUsage, + ThinkingBudget: &ThinkingBudget{Effort: "high"}, + ProviderOpts: map[string]any{"key": "value"}, + Routing: []RoutingRule{ + {Model: "fast", Examples: []string{"quick question"}}, + }, + } + + clone := original.Clone() + + // Mutate every pointer/collection field in the original. + *original.Temperature = 0.1 + *original.MaxTokens = 1 + *original.TopP = 0.1 + *original.ParallelToolCalls = false + *original.TrackUsage = false + original.ThinkingBudget.Effort = "low" + original.ProviderOpts["key"] = "mutated" + original.Routing[0].Examples[0] = "mutated" + + // Clone must be unaffected. + assert.InDelta(t, 0.7, *clone.Temperature, 0.001) + assert.Equal(t, int64(4096), *clone.MaxTokens) + assert.InDelta(t, 0.9, *clone.TopP, 0.001) + assert.True(t, *clone.ParallelToolCalls) + assert.True(t, *clone.TrackUsage) + assert.Equal(t, "high", clone.ThinkingBudget.Effort) + assert.Equal(t, "value", clone.ProviderOpts["key"]) + assert.Equal(t, "quick question", clone.Routing[0].Examples[0]) +} + +func TestModelConfig_Clone_Nil(t *testing.T) { + t.Parallel() + + var m *ModelConfig + assert.Nil(t, m.Clone()) +} + +func TestModelConfig_Clone_MinimalFields(t *testing.T) { + t.Parallel() + + original := &ModelConfig{ + Provider: "anthropic", + Model: "claude-sonnet-4-5", + } + + clone := original.Clone() + + assert.Equal(t, "anthropic", clone.Provider) + assert.Equal(t, "claude-sonnet-4-5", clone.Model) + assert.Nil(t, clone.Temperature) + assert.Nil(t, clone.MaxTokens) + assert.Nil(t, clone.ProviderOpts) + assert.Nil(t, clone.Routing) +} diff --git a/pkg/config/v7/model_ref.go b/pkg/config/v7/model_ref.go new file mode 100644 index 000000000..2ab6c3c4c --- /dev/null +++ b/pkg/config/v7/model_ref.go @@ -0,0 +1,20 @@ +package latest + +import ( + "fmt" + "strings" +) + +// ParseModelRef parses an inline "provider/model" reference into a +// ModelConfig. It returns an error when the string does not contain +// exactly one "/" separator or when either part is empty. +// +// cfg, err := ParseModelRef("openai/gpt-4o") +// // cfg.Provider == "openai", cfg.Model == "gpt-4o" +func ParseModelRef(ref string) (ModelConfig, error) { + providerName, model, ok := strings.Cut(ref, "/") + if !ok || providerName == "" || model == "" { + return ModelConfig{}, fmt.Errorf("invalid model reference %q: expected 'provider/model' format", ref) + } + return ModelConfig{Provider: providerName, Model: model}, nil +} diff --git a/pkg/config/v7/parse.go b/pkg/config/v7/parse.go new file mode 100644 index 000000000..4b3af7edb --- /dev/null +++ b/pkg/config/v7/parse.go @@ -0,0 +1,59 @@ +package latest + +import ( + "github.com/goccy/go-yaml" + + "github.com/docker/docker-agent/pkg/config/types" + previous "github.com/docker/docker-agent/pkg/config/v6" +) + +func Register(parsers map[string]func([]byte) (any, error), upgraders *[]func(any, []byte) (any, error)) { + parsers[Version] = func(d []byte) (any, error) { return parse(d) } + *upgraders = append(*upgraders, upgradeIfNeeded) +} + +func parse(data []byte) (Config, error) { + var cfg Config + err := yaml.UnmarshalWithOptions(data, &cfg, yaml.Strict()) + return cfg, err +} + +func upgradeIfNeeded(c any, _ []byte) (any, error) { + old, ok := c.(previous.Config) + if !ok { + return c, nil + } + + var config Config + types.CloneThroughJSON(old, &config) + + // Migrate AgentConfig.RAG []string → toolsets with type: rag + ref + for i, agent := range old.Agents { + if len(agent.RAG) == 0 { + continue + } + for _, ragName := range agent.RAG { + config.Agents[i].Toolsets = append(config.Agents[i].Toolsets, Toolset{ + Type: "rag", + Ref: ragName, + }) + } + } + + // Migrate top-level RAG map from RAGConfig to RAGToolset + if len(old.RAG) > 0 && config.RAG == nil { + config.RAG = make(map[string]RAGToolset) + } + for name, oldRAG := range old.RAG { + var ragCfg RAGConfig + types.CloneThroughJSON(oldRAG, &ragCfg) + config.RAG[name] = RAGToolset{ + Toolset: Toolset{ + Type: "rag", + RAGConfig: &ragCfg, + }, + } + } + + return config, nil +} diff --git a/pkg/config/v7/skills_config_test.go b/pkg/config/v7/skills_config_test.go new file mode 100644 index 000000000..a7465a15f --- /dev/null +++ b/pkg/config/v7/skills_config_test.go @@ -0,0 +1,270 @@ +package latest + +import ( + "encoding/json" + "testing" + + "github.com/goccy/go-yaml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestSkillsConfig_UnmarshalYAML(t *testing.T) { + tests := []struct { + name string + input string + expected SkillsConfig + }{ + { + name: "boolean true", + input: "true", + expected: SkillsConfig{Sources: []string{"local"}}, + }, + { + name: "boolean false", + input: "false", + expected: SkillsConfig{Sources: nil}, + }, + { + name: "list with local only", + input: "[local]", + expected: SkillsConfig{Sources: []string{"local"}}, + }, + { + name: "list with remote URL", + input: "[\"http://example.com\"]", + expected: SkillsConfig{Sources: []string{"http://example.com"}}, + }, + { + name: "list with local and remote", + input: "[local, \"https://skills.example.com\"]", + expected: SkillsConfig{Sources: []string{ + "local", + "https://skills.example.com", + }}, + }, + { + name: "multiline list", + input: `- local +- https://example.com +- http://internal.corp`, + expected: SkillsConfig{Sources: []string{ + "local", + "https://example.com", + "http://internal.corp", + }}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var cfg SkillsConfig + err := yaml.Unmarshal([]byte(tt.input), &cfg) + require.NoError(t, err) + assert.Equal(t, tt.expected, cfg) + }) + } +} + +func TestSkillsConfig_MarshalYAML(t *testing.T) { + tests := []struct { + name string + input SkillsConfig + expected string + }{ + { + name: "disabled", + input: SkillsConfig{}, + expected: "false\n", + }, + { + name: "local only marshals as true", + input: SkillsConfig{Sources: []string{"local"}}, + expected: "true\n", + }, + { + name: "list with remote", + input: SkillsConfig{Sources: []string{"local", "https://example.com"}}, + expected: "- local\n- https://example.com\n", + }, + { + name: "remote only", + input: SkillsConfig{Sources: []string{"https://example.com"}}, + expected: "- https://example.com\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := yaml.Marshal(tt.input) + require.NoError(t, err) + assert.Equal(t, tt.expected, string(out)) + }) + } +} + +func TestSkillsConfig_UnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input string + expected SkillsConfig + }{ + { + name: "boolean true", + input: "true", + expected: SkillsConfig{Sources: []string{"local"}}, + }, + { + name: "boolean false", + input: "false", + expected: SkillsConfig{Sources: nil}, + }, + { + name: "list with local", + input: `["local"]`, + expected: SkillsConfig{Sources: []string{"local"}}, + }, + { + name: "list with remote URLs", + input: `["local", "https://skills.example.com"]`, + expected: SkillsConfig{Sources: []string{"local", "https://skills.example.com"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var cfg SkillsConfig + err := json.Unmarshal([]byte(tt.input), &cfg) + require.NoError(t, err) + assert.Equal(t, tt.expected, cfg) + }) + } +} + +func TestSkillsConfig_MarshalJSON(t *testing.T) { + tests := []struct { + name string + input SkillsConfig + expected string + }{ + { + name: "disabled", + input: SkillsConfig{}, + expected: "false", + }, + { + name: "local only as true", + input: SkillsConfig{Sources: []string{"local"}}, + expected: "true", + }, + { + name: "list with remote", + input: SkillsConfig{Sources: []string{"local", "https://example.com"}}, + expected: `["local","https://example.com"]`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := json.Marshal(tt.input) + require.NoError(t, err) + assert.Equal(t, tt.expected, string(out)) + }) + } +} + +func TestSkillsConfig_Enabled(t *testing.T) { + assert.False(t, SkillsConfig{}.Enabled()) + assert.False(t, SkillsConfig{Sources: nil}.Enabled()) + assert.False(t, SkillsConfig{Sources: []string{}}.Enabled()) + assert.True(t, SkillsConfig{Sources: []string{"local"}}.Enabled()) + assert.True(t, SkillsConfig{Sources: []string{"https://example.com"}}.Enabled()) +} + +func TestSkillsConfig_HasLocal(t *testing.T) { + assert.False(t, SkillsConfig{}.HasLocal()) + assert.False(t, SkillsConfig{Sources: []string{"https://example.com"}}.HasLocal()) + assert.True(t, SkillsConfig{Sources: []string{"local"}}.HasLocal()) + assert.True(t, SkillsConfig{Sources: []string{"local", "https://example.com"}}.HasLocal()) +} + +func TestSkillsConfig_RemoteURLs(t *testing.T) { + assert.Empty(t, SkillsConfig{}.RemoteURLs()) + assert.Empty(t, SkillsConfig{Sources: []string{"local"}}.RemoteURLs()) + assert.Equal(t, + []string{"https://example.com", "http://internal.corp"}, + SkillsConfig{Sources: []string{"local", "https://example.com", "http://internal.corp"}}.RemoteURLs(), + ) +} + +func TestSkillsConfig_JSONRoundTrip(t *testing.T) { + // This tests the upgrade path from v4 (bool) to v5 (SkillsConfig) via CloneThroughJSON + t.Run("bool true round trips through JSON", func(t *testing.T) { + jsonData := []byte("true") + var cfg SkillsConfig + require.NoError(t, json.Unmarshal(jsonData, &cfg)) + assert.True(t, cfg.Enabled()) + assert.True(t, cfg.HasLocal()) + assert.Equal(t, []string{"local"}, cfg.Sources) + + out, err := json.Marshal(cfg) + require.NoError(t, err) + assert.Equal(t, "true", string(out)) + }) + + t.Run("bool false round trips through JSON", func(t *testing.T) { + jsonData := []byte("false") + var cfg SkillsConfig + require.NoError(t, json.Unmarshal(jsonData, &cfg)) + assert.False(t, cfg.Enabled()) + assert.Nil(t, cfg.Sources) + + out, err := json.Marshal(cfg) + require.NoError(t, err) + assert.Equal(t, "false", string(out)) + }) + + t.Run("list round trips through JSON", func(t *testing.T) { + jsonData := []byte(`["local","https://example.com"]`) + var cfg SkillsConfig + require.NoError(t, json.Unmarshal(jsonData, &cfg)) + assert.True(t, cfg.Enabled()) + assert.Equal(t, []string{"local", "https://example.com"}, cfg.Sources) + + out, err := json.Marshal(cfg) + require.NoError(t, err) + assert.Equal(t, `["local","https://example.com"]`, string(out)) + }) +} + +func TestSkillsConfig_InAgentConfig(t *testing.T) { + yamlInput := ` +model: openai/gpt-4 +skills: + - local + - https://skills.example.com +toolsets: + - type: filesystem +` + var agent AgentConfig + err := yaml.Unmarshal([]byte(yamlInput), &agent) + require.NoError(t, err) + assert.True(t, agent.Skills.Enabled()) + assert.True(t, agent.Skills.HasLocal()) + assert.Equal(t, []string{"https://skills.example.com"}, agent.Skills.RemoteURLs()) +} + +func TestSkillsConfig_InAgentConfigBool(t *testing.T) { + yamlInput := ` +model: openai/gpt-4 +skills: true +toolsets: + - type: filesystem +` + var agent AgentConfig + err := yaml.Unmarshal([]byte(yamlInput), &agent) + require.NoError(t, err) + assert.True(t, agent.Skills.Enabled()) + assert.True(t, agent.Skills.HasLocal()) + assert.Empty(t, agent.Skills.RemoteURLs()) +} diff --git a/pkg/config/v7/types.go b/pkg/config/v7/types.go new file mode 100644 index 000000000..2396939a1 --- /dev/null +++ b/pkg/config/v7/types.go @@ -0,0 +1,1534 @@ +package latest + +import ( + "cmp" + "encoding/json" + "errors" + "fmt" + "maps" + "slices" + "strings" + "time" + + "github.com/goccy/go-yaml" + + "github.com/docker/docker-agent/pkg/config/types" + "github.com/docker/docker-agent/pkg/effort" +) + +const Version = "7" + +// Config represents the entire configuration file +type Config struct { + Version string `json:"version,omitempty"` + Agents Agents `json:"agents,omitempty"` + Providers map[string]ProviderConfig `json:"providers,omitempty"` + Models map[string]ModelConfig `json:"models,omitempty"` + MCPs map[string]MCPToolset `json:"mcps,omitempty"` + RAG map[string]RAGToolset `json:"rag,omitempty"` + Metadata Metadata `json:"metadata"` + Permissions *PermissionsConfig `json:"permissions,omitempty"` +} + +// MCPToolset is a reusable MCP server definition stored in the top-level +// "mcps" section. It is identical to a Toolset but skips the normal +// Toolset.validate() call during YAML unmarshaling because the "type" +// field is implicit (always "mcp") and the source (command/remote/ref) +// is validated later during config resolution. +type MCPToolset struct { + Toolset `json:",inline" yaml:",inline"` +} + +func (m *MCPToolset) UnmarshalYAML(unmarshal func(any) error) error { + // Use a plain alias to avoid triggering Toolset.UnmarshalYAML + // (which calls validate and requires "type" to be set). + type alias Toolset + var tmp alias + if err := unmarshal(&tmp); err != nil { + return err + } + m.Toolset = Toolset(tmp) + m.Type = "mcp" + return m.validate() +} + +// RAGToolset is a reusable RAG source definition stored in the top-level +// "rag" section. It is identical to a Toolset but skips the normal +// Toolset.validate() call during YAML unmarshaling because the "type" +// field is implicit (always "rag") and the RAG config is validated +// during config resolution. +type RAGToolset struct { + Toolset `json:",inline" yaml:",inline"` +} + +func (r RAGToolset) MarshalYAML() (any, error) { + // Flatten RAGConfig fields alongside toolset fields into a single map. + result := make(map[string]any) + + if r.Instruction != "" { + result["instruction"] = r.Instruction + } + if len(r.Tools) > 0 { + result["tools"] = r.Tools + } + if r.Name != "" { + result["name"] = r.Name + } + if !r.Defer.IsEmpty() { + result["defer"] = r.Defer + } + + if r.RAGConfig != nil { + cfg := r.RAGConfig + result["tool"] = cfg.Tool + if len(cfg.Docs) > 0 { + result["docs"] = cfg.Docs + } + if cfg.RespectVCS != nil { + result["respect_vcs"] = *cfg.RespectVCS + } + if len(cfg.Strategies) > 0 { + result["strategies"] = cfg.Strategies + } + result["results"] = cfg.Results + } + + return result, nil +} + +func (r *RAGToolset) UnmarshalYAML(unmarshal func(any) error) error { + // RAGToolset flattens RAGConfig fields directly at the top level, + // so users write tool/docs/strategies alongside toolset fields + // (instruction, tools, name, defer) without a rag_config wrapper. + // + // We unmarshal into a raw map first to avoid strict-mode errors + // from fields that belong to RAGConfig but not Toolset. + var raw map[string]any + if err := unmarshal(&raw); err != nil { + return err + } + + // Extract toolset-level fields + var tf Toolset + tf.Type = "rag" + if v, ok := raw["instruction"].(string); ok { + tf.Instruction = v + } + if v, ok := raw["name"].(string); ok { + tf.Name = v + } + if v, ok := raw["tools"]; ok { + if arr, ok := v.([]any); ok { + for _, item := range arr { + if s, ok := item.(string); ok { + tf.Tools = append(tf.Tools, s) + } + } + } + } + if v, ok := raw["defer"]; ok { + data, _ := yaml.Marshal(v) + _ = yaml.Unmarshal(data, &tf.Defer) + } + + // Unmarshal RAGConfig from the same map (it has its own UnmarshalYAML) + var ragCfg RAGConfig + if err := unmarshal(&ragCfg); err != nil { + return err + } + + tf.RAGConfig = &ragCfg + r.Toolset = tf + return nil +} + +type Agents []AgentConfig + +func (c *Agents) UnmarshalYAML(unmarshal func(any) error) error { + var items yaml.MapSlice + if err := unmarshal(&items); err != nil { + return err + } + + agents := make([]AgentConfig, 0, len(items)) + for _, item := range items { + name, ok := item.Key.(string) + if !ok { + return errors.New("agent name must be a string") + } + + valueBytes, err := yaml.Marshal(item.Value) + if err != nil { + return fmt.Errorf("failed to marshal agent config for %s: %w", name, err) + } + + var agent AgentConfig + if err := yaml.UnmarshalWithOptions(valueBytes, &agent, yaml.DisallowUnknownField()); err != nil { + return fmt.Errorf("failed to unmarshal agent config for %s: %w", name, err) + } + + agent.Name = name + agents = append(agents, agent) + } + + *c = agents + return nil +} + +func (c Agents) MarshalYAML() (any, error) { + mapSlice := make(yaml.MapSlice, 0, len(c)) + + for _, agent := range c { + mapSlice = append(mapSlice, yaml.MapItem{ + Key: agent.Name, + Value: agent, + }) + } + + return mapSlice, nil +} + +func (c *Agents) First() AgentConfig { + if len(*c) > 0 { + return (*c)[0] + } + panic("no agents configured") +} + +func (c *Agents) Lookup(name string) (AgentConfig, bool) { + for _, agent := range *c { + if agent.Name == name { + return agent, true + } + } + return AgentConfig{}, false +} + +func (c *Agents) Update(name string, update func(a *AgentConfig)) bool { + for i := range *c { + if (*c)[i].Name == name { + update(&(*c)[i]) + return true + } + } + return false +} + +// ProviderConfig represents a reusable provider definition. +// It allows users to define custom providers with default base URLs and token keys. +// Models can reference these providers by name, inheriting the defaults. +type ProviderConfig struct { + // APIType specifies which API schema to use. Supported values: + // - "openai_chatcompletions" (default): Use the OpenAI Chat Completions API + // - "openai_responses": Use the OpenAI Responses API + APIType string `json:"api_type,omitempty"` + // BaseURL is the base URL for the provider's API endpoint + BaseURL string `json:"base_url"` + // TokenKey is the environment variable name containing the API token + TokenKey string `json:"token_key,omitempty"` +} + +// FallbackConfig represents fallback model configuration for an agent. +// Controls which models to try when the primary fails and how retries/cooldowns work. +// Most users only need to specify Models — the defaults handle common scenarios automatically. +type FallbackConfig struct { + // Models is a list of fallback models to try in order if the primary fails. + // Each entry can be a model name from the models section or an inline provider/model format. + Models []string `json:"models,omitempty"` + // Retries is the number of retries per model with exponential backoff. + // Default is 2 (giving 3 total attempts per model). Use -1 to disable retries entirely. + // Retries only apply to retryable errors (5xx, timeouts); non-retryable errors (429, 4xx) + // skip immediately to the next model. + Retries int `json:"retries,omitempty"` + // Cooldown is the duration to stick with a successful fallback model before + // retrying the primary. Only applies after a non-retryable error (e.g., 429). + // Default is 1 minute. Use Go duration format (e.g., "1m", "30s", "2m30s"). + Cooldown Duration `json:"cooldown"` +} + +// Duration is a wrapper around time.Duration that supports YAML/JSON unmarshaling +// from string format (e.g., "1m", "30s", "2h30m"). +type Duration struct { + time.Duration +} + +// UnmarshalYAML implements custom unmarshaling for Duration from string format +func (d *Duration) UnmarshalYAML(unmarshal func(any) error) error { + if d == nil { + return errors.New("cannot unmarshal into nil Duration") + } + + var s string + if err := unmarshal(&s); err != nil { + // Try as integer (seconds) + var secs int + if err2 := unmarshal(&secs); err2 == nil { + d.Duration = time.Duration(secs) * time.Second + return nil + } + return err + } + if s == "" { + d.Duration = 0 + return nil + } + dur, err := time.ParseDuration(s) + if err != nil { + return fmt.Errorf("invalid duration format %q: %w", s, err) + } + d.Duration = dur + return nil +} + +// MarshalYAML implements custom marshaling for Duration to string format +func (d Duration) MarshalYAML() (any, error) { + if d.Duration == 0 { + return "", nil + } + return d.String(), nil +} + +// UnmarshalJSON implements custom unmarshaling for Duration from string format +func (d *Duration) UnmarshalJSON(data []byte) error { + if d == nil { + return errors.New("cannot unmarshal into nil Duration") + } + + var s string + if err := json.Unmarshal(data, &s); err != nil { + // Try as integer (seconds) + var secs int + if err2 := json.Unmarshal(data, &secs); err2 == nil { + d.Duration = time.Duration(secs) * time.Second + return nil + } + return err + } + if s == "" { + d.Duration = 0 + return nil + } + dur, err := time.ParseDuration(s) + if err != nil { + return fmt.Errorf("invalid duration format %q: %w", s, err) + } + d.Duration = dur + return nil +} + +// MarshalJSON implements custom marshaling for Duration to string format +func (d Duration) MarshalJSON() ([]byte, error) { + if d.Duration == 0 { + return json.Marshal("") + } + return json.Marshal(d.String()) +} + +// AgentConfig represents a single agent configuration +type AgentConfig struct { + Name string + Model string `json:"model,omitempty"` + Fallback *FallbackConfig `json:"fallback,omitempty"` + Description string `json:"description,omitempty"` + WelcomeMessage string `json:"welcome_message,omitempty"` + Toolsets []Toolset `json:"toolsets,omitempty"` + Instruction string `json:"instruction,omitempty"` + SubAgents []string `json:"sub_agents,omitempty"` + Handoffs []string `json:"handoffs,omitempty"` + + AddDate bool `json:"add_date,omitempty"` + AddEnvironmentInfo bool `json:"add_environment_info,omitempty"` + CodeModeTools bool `json:"code_mode_tools,omitempty"` + AddDescriptionParameter bool `json:"add_description_parameter,omitempty"` + MaxIterations int `json:"max_iterations,omitempty"` + MaxConsecutiveToolCalls int `json:"max_consecutive_tool_calls,omitempty"` + MaxOldToolCallTokens int `json:"max_old_tool_call_tokens,omitempty"` + NumHistoryItems int `json:"num_history_items,omitempty"` + AddPromptFiles []string `json:"add_prompt_files,omitempty" yaml:"add_prompt_files,omitempty"` + Commands types.Commands `json:"commands,omitempty"` + StructuredOutput *StructuredOutput `json:"structured_output,omitempty"` + Skills SkillsConfig `json:"skills,omitzero"` + Hooks *HooksConfig `json:"hooks,omitempty"` +} + +const SkillSourceLocal = "local" + +// SkillsConfig controls skill discovery sources for an agent. +// Supports three YAML formats: +// - Boolean: `skills: true` (equivalent to ["local"]) or `skills: false` (disabled) +// - List: `skills: ["local", "http://example.com"]` +// +// The special source "local" loads skills from the filesystem (standard locations). +// HTTP/HTTPS URLs load skills from remote servers per the well-known skills discovery spec. +type SkillsConfig struct { //nolint:recvcheck // MarshalYAML/MarshalJSON must use value receiver, UnmarshalYAML/UnmarshalJSON must use pointer + Sources []string +} + +func (s SkillsConfig) Enabled() bool { + return len(s.Sources) > 0 +} + +func (s SkillsConfig) HasLocal() bool { + return slices.Contains(s.Sources, SkillSourceLocal) +} + +func (s SkillsConfig) RemoteURLs() []string { + var urls []string + for _, src := range s.Sources { + if strings.HasPrefix(src, "http://") || strings.HasPrefix(src, "https://") { + urls = append(urls, src) + } + } + return urls +} + +func (s *SkillsConfig) UnmarshalYAML(unmarshal func(any) error) error { + var b bool + if err := unmarshal(&b); err == nil { + if b { + s.Sources = []string{SkillSourceLocal} + } else { + s.Sources = nil + } + return nil + } + + var sources []string + if err := unmarshal(&sources); err != nil { + return errors.New("skills must be a boolean or a list of sources") + } + s.Sources = sources + return nil +} + +func (s SkillsConfig) MarshalYAML() (any, error) { + if len(s.Sources) == 0 { + return false, nil + } + if len(s.Sources) == 1 && s.Sources[0] == SkillSourceLocal { + return true, nil + } + return s.Sources, nil +} + +func (s *SkillsConfig) UnmarshalJSON(data []byte) error { + var b bool + if err := json.Unmarshal(data, &b); err == nil { + if b { + s.Sources = []string{SkillSourceLocal} + } else { + s.Sources = nil + } + return nil + } + + var sources []string + if err := json.Unmarshal(data, &sources); err != nil { + return errors.New("skills must be a boolean or a list of sources") + } + s.Sources = sources + return nil +} + +func (s SkillsConfig) MarshalJSON() ([]byte, error) { + if len(s.Sources) == 0 { + return json.Marshal(false) + } + if len(s.Sources) == 1 && s.Sources[0] == SkillSourceLocal { + return json.Marshal(true) + } + return json.Marshal(s.Sources) +} + +// GetFallbackModels returns the fallback models from the config. +func (a *AgentConfig) GetFallbackModels() []string { + if a.Fallback != nil { + return a.Fallback.Models + } + return nil +} + +// GetFallbackRetries returns the fallback retries from the config. +func (a *AgentConfig) GetFallbackRetries() int { + if a.Fallback != nil { + return a.Fallback.Retries + } + return 0 +} + +// GetFallbackCooldown returns the fallback cooldown duration from the config. +// Returns the configured cooldown, or 0 if not set (caller should apply default). +func (a *AgentConfig) GetFallbackCooldown() time.Duration { + if a.Fallback != nil { + return a.Fallback.Cooldown.Duration + } + return 0 +} + +// ModelConfig represents the configuration for a model +type ModelConfig struct { + // Name is the manifest model name (map key), populated at runtime. + // Not serialized — set by teamloader/model_switcher when resolving models. + Name string `json:"-"` + Provider string `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + // DisplayModel holds the original model name from the YAML config, before alias resolution. + // When set, provider.ID() returns Provider + "/" + DisplayModel instead of the resolved name. + // This ensures the UI shows the user-configured name (e.g., "claude-haiku-4-5") + // while the API uses the resolved name (e.g., "claude-haiku-4-5-20251001"). + DisplayModel string `json:"-"` + Temperature *float64 `json:"temperature,omitempty"` + MaxTokens *int64 `json:"max_tokens,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` + PresencePenalty *float64 `json:"presence_penalty,omitempty"` + BaseURL string `json:"base_url,omitempty"` + ParallelToolCalls *bool `json:"parallel_tool_calls,omitempty"` + TokenKey string `json:"token_key,omitempty"` + // ProviderOpts allows provider-specific options. + ProviderOpts map[string]any `json:"provider_opts,omitempty"` + TrackUsage *bool `json:"track_usage,omitempty"` + // ThinkingBudget controls reasoning effort/budget. + // Accepts an integer token count or a string effort level. + // See [effort.ValidNames] for the full list of accepted strings. + // Provider-specific mappings are in the effort package. + ThinkingBudget *ThinkingBudget `json:"thinking_budget,omitempty"` + // Routing defines rules for routing requests to different models. + // When routing is configured, this model becomes a rule-based router: + // - The provider/model fields define the fallback model + // - Each routing rule maps to a different model based on examples + Routing []RoutingRule `json:"routing,omitempty"` +} + +// Clone returns a deep copy of the ModelConfig. +func (m *ModelConfig) Clone() *ModelConfig { + if m == nil { + return nil + } + var c ModelConfig + types.CloneThroughJSON(m, &c) + // Preserve fields excluded from JSON serialization + c.Name = m.Name + c.DisplayModel = m.DisplayModel + return &c +} + +// DisplayOrModel returns DisplayModel if set (i.e., alias resolution preserved the original name), +// otherwise falls back to Model. +func (m *ModelConfig) DisplayOrModel() string { + return cmp.Or(m.DisplayModel, m.Model) +} + +// FlexibleModelConfig wraps ModelConfig to support both shorthand and full syntax. +// It can be unmarshaled from either: +// - A shorthand string: "provider/model" (e.g., "anthropic/claude-sonnet-4-5") +// - A full model definition with all options +type FlexibleModelConfig struct { + ModelConfig +} + +// UnmarshalYAML implements custom unmarshaling for flexible model config +func (f *FlexibleModelConfig) UnmarshalYAML(unmarshal func(any) error) error { + // Try string shorthand first + var shorthand string + if err := unmarshal(&shorthand); err == nil && shorthand != "" { + parsed, parseErr := ParseModelRef(shorthand) + if parseErr != nil { + return fmt.Errorf("invalid model shorthand %q: expected format 'provider/model'", shorthand) + } + f.Provider = parsed.Provider + f.Model = parsed.Model + return nil + } + + // Try full model config + var cfg ModelConfig + if err := unmarshal(&cfg); err != nil { + return err + } + f.ModelConfig = cfg + return nil +} + +// MarshalYAML outputs shorthand format if only provider/model are set +func (f FlexibleModelConfig) MarshalYAML() (any, error) { + if f.isShorthandOnly() { + return f.Provider + "/" + f.Model, nil + } + return f.ModelConfig, nil +} + +// isShorthandOnly returns true if only provider and model are set +func (f *FlexibleModelConfig) isShorthandOnly() bool { + return f.Temperature == nil && + f.MaxTokens == nil && + f.TopP == nil && + f.FrequencyPenalty == nil && + f.PresencePenalty == nil && + f.BaseURL == "" && + f.ParallelToolCalls == nil && + f.TokenKey == "" && + len(f.ProviderOpts) == 0 && + f.TrackUsage == nil && + f.ThinkingBudget == nil && + len(f.Routing) == 0 +} + +// RoutingRule defines a single routing rule for model selection. +// Each rule maps example phrases to a target model. +type RoutingRule struct { + // Model is a reference to another model in the models section or an inline model spec (e.g., "openai/gpt-4o") + Model string `json:"model"` + // Examples are phrases that should trigger routing to this model + Examples []string `json:"examples"` +} + +type Metadata struct { + Author string `json:"author,omitempty"` + License string `json:"license,omitempty"` + Description string `json:"description,omitempty"` + Readme string `json:"readme,omitempty"` + Version string `json:"version,omitempty"` +} + +// Commands represents a set of named prompts for quick-starting conversations. +// It supports two YAML formats: +// +// commands: +// +// df: "check disk space" +// ls: "list files" +// +// or +// +// commands: +// - df: "check disk space" +// - ls: "list files" +// Commands YAML unmarshalling is implemented in pkg/config/types/commands.go + +// ScriptShellToolConfig represents a custom shell tool configuration +type ScriptShellToolConfig struct { + Cmd string `json:"cmd"` + Description string `json:"description"` + + // Args is directly passed as "properties" in the JSON schema + Args map[string]any `json:"args,omitempty"` + + // Required is directly passed as "required" in the JSON schema + Required []string `json:"required"` + + Env map[string]string `json:"env,omitempty"` + WorkingDir string `json:"working_dir,omitempty"` +} + +type APIToolConfig struct { + Instruction string `json:"instruction,omitempty"` + Name string `json:"name,omitempty"` + Required []string `json:"required,omitempty"` + Args map[string]any `json:"args,omitempty"` + Endpoint string `json:"endpoint,omitempty"` + Method string `json:"method,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + // OutputSchema optionally describes the API response as JSON Schema for MCP/Code Mode consumers; runtime still returns the raw string body. + OutputSchema map[string]any `json:"output_schema,omitempty"` +} + +// PostEditConfig represents a post-edit command configuration +type PostEditConfig struct { + Path string `json:"path"` + Cmd string `json:"cmd"` +} + +// Toolset represents a tool configuration +type Toolset struct { + Type string `json:"type,omitempty"` + Tools []string `json:"tools,omitempty"` + Instruction string `json:"instruction,omitempty"` + Toon string `json:"toon,omitempty"` + + // Model overrides the LLM used for the turn that processes tool results + // from this toolset, enabling per-toolset model routing. Value can be a + // model name from the models section or "provider/model" (e.g. "openai/gpt-4o-mini"). + Model string `json:"model,omitempty"` + + Defer DeferConfig `json:"defer" yaml:"defer,omitempty"` + + // For the `mcp` tool + Command string `json:"command,omitempty"` + Args []string `json:"args,omitempty"` + Ref string `json:"ref,omitempty"` + Remote Remote `json:"remote"` + Config any `json:"config,omitempty"` + + // For `mcp` and `lsp` tools - version/package reference for auto-installation. + // Format: "owner/repo" or "owner/repo@version" + // When empty and auto-install is enabled, docker agent auto-detects from the command name. + // Set to "false" or "off" to disable auto-install for this toolset. + Version string `json:"version,omitempty"` + + // For the `a2a` and `openapi` tools + Name string `json:"name,omitempty"` + URL string `json:"url,omitempty"` + Headers map[string]string `json:"headers,omitempty"` + + // For `shell`, `script`, `mcp` or `lsp` tools + Env map[string]string `json:"env,omitempty"` + + // For the `todo` tool + Shared bool `json:"shared,omitempty"` + + // For the `memory` and `tasks` tools + Path string `json:"path,omitempty"` + + // For the `script` tool + Shell map[string]ScriptShellToolConfig `json:"shell,omitempty"` + + // For the `filesystem` tool - post-edit commands + PostEdit []PostEditConfig `json:"post_edit,omitempty"` + + APIConfig APIToolConfig `json:"api_config"` + + // For the `filesystem` tool - VCS integration + IgnoreVCS *bool `json:"ignore_vcs,omitempty"` + + // For the `lsp` tool + FileTypes []string `json:"file_types,omitempty"` + + // For the `fetch` tool + Timeout int `json:"timeout,omitempty"` + + // For the `rag` tool + RAGConfig *RAGConfig `json:"rag_config,omitempty" yaml:"rag_config,omitempty"` + + // For the `model_picker` tool + Models []string `json:"models,omitempty"` +} + +func (t *Toolset) UnmarshalYAML(unmarshal func(any) error) error { + type alias Toolset + var tmp alias + if err := unmarshal(&tmp); err != nil { + return err + } + *t = Toolset(tmp) + return t.validate() +} + +type Remote struct { + URL string `json:"url"` + TransportType string `json:"transport_type,omitempty"` + Headers map[string]string `json:"headers,omitempty"` +} + +// DeferConfig represents the deferred loading configuration for a toolset. +// It can be either a boolean (true to defer all tools) or a slice of strings +// (list of tool names to defer). +type DeferConfig struct { //nolint:recvcheck // MarshalYAML must use value receiver for YAML slice encoding, UnmarshalYAML must use pointer + // DeferAll is true when all tools should be deferred + DeferAll bool `json:"-"` + // Tools is the list of specific tool names to defer (empty if DeferAll is true) + Tools []string `json:"-"` +} + +func (d DeferConfig) IsEmpty() bool { + return !d.DeferAll && len(d.Tools) == 0 +} + +func (d *DeferConfig) UnmarshalYAML(unmarshal func(any) error) error { + var b bool + if err := unmarshal(&b); err == nil { + d.DeferAll = b + d.Tools = nil + return nil + } + + var tools []string + if err := unmarshal(&tools); err == nil { + d.DeferAll = false + d.Tools = tools + return nil + } + + return nil +} + +func (d DeferConfig) MarshalYAML() (any, error) { + if d.DeferAll { + return true, nil + } + if len(d.Tools) == 0 { + // Return false for empty config - this will be omitted by yaml encoder + return false, nil + } + return d.Tools, nil +} + +// ThinkingBudget represents reasoning budget configuration. +// It accepts either a string effort level (see [effort.ValidNames]) or an +// integer token budget. +type ThinkingBudget struct { + // Effort stores string-based reasoning effort levels + Effort string `json:"effort,omitempty"` + // Tokens stores integer-based token budgets + Tokens int `json:"tokens,omitempty"` +} + +func (t *ThinkingBudget) UnmarshalYAML(unmarshal func(any) error) error { + // Try integer tokens first + var n int + if err := unmarshal(&n); err == nil { + *t = ThinkingBudget{Tokens: n} + return nil + } + + // Try string level + var s string + if err := unmarshal(&s); err == nil { + if !effort.IsValid(s) { + return fmt.Errorf("invalid thinking_budget effort %q: must be one of %s", s, effort.ValidNames()) + } + *t = ThinkingBudget{Effort: s} + return nil + } + + return nil +} + +// MarshalYAML implements custom marshaling to output simple string or int format +func (t ThinkingBudget) MarshalYAML() (any, error) { + // If Effort string is set (non-empty), marshal as string + if t.Effort != "" { + return t.Effort, nil + } + + // Otherwise marshal as integer (includes 0, -1, and positive values) + return t.Tokens, nil +} + +// IsDisabled returns true if the thinking budget is explicitly disabled. +// A nil receiver is treated as "not configured" (not disabled). +// +// Disabled when: +// - Tokens == 0 with no Effort (thinking_budget: 0) +// - Effort == "none" (thinking_budget: none) +// +// NOT disabled when: +// - Tokens > 0 or Tokens == -1 (explicit token budget) +// - Effort is a real level like "medium" or "high" +// - Effort is "adaptive" +func (t *ThinkingBudget) IsDisabled() bool { + if t == nil { + return false + } + if t.Tokens == 0 && t.Effort == "" { + return true + } + return strings.EqualFold(t.Effort, "none") +} + +// IsAdaptive returns true if the thinking budget is set to adaptive mode. +// Adaptive thinking lets the model decide how much thinking to do. +// Matches both "adaptive" and "adaptive/" formats. +func (t *ThinkingBudget) IsAdaptive() bool { + if t == nil { + return false + } + norm := strings.ToLower(strings.TrimSpace(t.Effort)) + return norm == "adaptive" || strings.HasPrefix(norm, "adaptive/") +} + +// EffortLevel parses the Effort field into an [effort.Level]. +// Returns ("", false) when the budget is nil, uses token counts, or has an +// unrecognised effort string. +func (t *ThinkingBudget) EffortLevel() (effort.Level, bool) { + if t == nil { + return "", false + } + return effort.Parse(t.Effort) +} + +// AdaptiveEffort returns the effort level for adaptive thinking. +// For "adaptive" it returns the default ("high"). +// For "adaptive/" it returns the specified effort. +// Returns ("", false) if the budget is not adaptive. +func (t *ThinkingBudget) AdaptiveEffort() (string, bool) { + if !t.IsAdaptive() { + return "", false + } + norm := strings.ToLower(strings.TrimSpace(t.Effort)) + if after, ok := strings.CutPrefix(norm, "adaptive/"); ok && after != "" { + return after, true + } + return "high", true +} + +// EffortTokens maps a string effort level to a token budget for providers +// that only support token-based thinking (e.g. Bedrock Claude). +// Delegates to [effort.BedrockTokens]. +// +// Returns (tokens, true) when a mapping exists, or (0, false) when +// the budget uses an explicit token count or an unrecognised effort string. +func (t *ThinkingBudget) EffortTokens() (int, bool) { + l, ok := t.EffortLevel() + if !ok { + return 0, false + } + return effort.BedrockTokens(l) +} + +// MarshalJSON implements custom marshaling to output simple string or int format +// This ensures JSON and YAML have the same flattened format for consistency +func (t ThinkingBudget) MarshalJSON() ([]byte, error) { + // If Effort string is set (non-empty), marshal as string + if t.Effort != "" { + return fmt.Appendf(nil, "%q", t.Effort), nil + } + + // Otherwise marshal as integer (includes 0, -1, and positive values) + return fmt.Appendf(nil, "%d", t.Tokens), nil +} + +// UnmarshalJSON implements custom unmarshaling to accept simple string or int format +// This ensures JSON and YAML have the same flattened format for consistency +func (t *ThinkingBudget) UnmarshalJSON(data []byte) error { + // Try integer tokens first + var n int + if err := json.Unmarshal(data, &n); err == nil { + *t = ThinkingBudget{Tokens: n} + return nil + } + + // Try string level + var s string + if err := json.Unmarshal(data, &s); err == nil { + if !effort.IsValid(s) { + return fmt.Errorf("invalid thinking_budget effort %q: must be one of %s", s, effort.ValidNames()) + } + *t = ThinkingBudget{Effort: s} + return nil + } + + return nil +} + +// StructuredOutput defines a JSON schema for structured output +type StructuredOutput struct { + // Name is the name of the response format + Name string `json:"name"` + // Description is optional description of the response format + Description string `json:"description,omitempty"` + // Schema is a JSON schema object defining the structure + Schema map[string]any `json:"schema"` + // Strict enables strict schema adherence (OpenAI only) + Strict bool `json:"strict,omitempty"` +} + +// RAGToolConfig represents tool-specific configuration for a RAG source +type RAGToolConfig struct { + Name string `json:"name,omitempty"` // Custom name for the tool (defaults to RAG source name if empty) + Description string `json:"description,omitempty"` // Tool description (what the tool does) + Instruction string `json:"instruction,omitempty"` // Tool instruction (how to use the tool effectively) +} + +// RAGConfig represents a RAG (Retrieval-Augmented Generation) configuration +// Uses a unified strategies array for flexible, extensible configuration +type RAGConfig struct { + Tool RAGToolConfig `json:"tool"` // Tool configuration + Docs []string `json:"docs,omitempty"` // Shared documents across all strategies + RespectVCS *bool `json:"respect_vcs,omitempty"` // Whether to respect VCS ignore files like .gitignore (default: true) + Strategies []RAGStrategyConfig `json:"strategies,omitempty"` // Array of strategy configurations + Results RAGResultsConfig `json:"results"` +} + +// GetRespectVCS returns whether VCS ignore files should be respected, defaulting to true +func (c *RAGConfig) GetRespectVCS() bool { + if c.RespectVCS == nil { + return true + } + return *c.RespectVCS +} + +// RAGStrategyConfig represents a single retrieval strategy configuration +// Strategy-specific fields are stored in Params (validated by strategy implementation) +type RAGStrategyConfig struct { //nolint:recvcheck // Marshal methods must use value receiver for YAML/JSON slice encoding, Unmarshal must use pointer + Type string `json:"type"` // Strategy type: "chunked-embeddings", "bm25", etc. + Docs []string `json:"docs,omitempty"` // Strategy-specific documents (augments shared docs) + Database RAGDatabaseConfig `json:"database"` // Database configuration + Chunking RAGChunkingConfig `json:"chunking"` // Chunking configuration + Limit int `json:"limit,omitempty"` // Max results from this strategy (for fusion input) + + // Strategy-specific parameters (arbitrary key-value pairs) + // Examples: + // - chunked-embeddings: embedding_model, similarity_metric, threshold, vector_dimensions + // - bm25: k1, b, threshold + Params map[string]any // Flattened into parent JSON +} + +// UnmarshalYAML implements custom unmarshaling to capture all extra fields into Params +// This allows strategies to have flexible, strategy-specific configuration parameters +// without requiring changes to the core config schema +func (s *RAGStrategyConfig) UnmarshalYAML(unmarshal func(any) error) error { + // First unmarshal into a map to capture everything + var raw map[string]any + if err := unmarshal(&raw); err != nil { + return err + } + + // Extract known fields + if t, ok := raw["type"].(string); ok { + s.Type = t + delete(raw, "type") + } + + if docs, ok := raw["docs"].([]any); ok { + s.Docs = make([]string, len(docs)) + for i, d := range docs { + if str, ok := d.(string); ok { + s.Docs[i] = str + } + } + delete(raw, "docs") + } + + if dbRaw, ok := raw["database"]; ok { + // Unmarshal database config using helper + var db RAGDatabaseConfig + unmarshalDatabaseConfig(dbRaw, &db) + s.Database = db + delete(raw, "database") + } + + if chunkRaw, ok := raw["chunking"]; ok { + var chunk RAGChunkingConfig + unmarshalChunkingConfig(chunkRaw, &chunk) + s.Chunking = chunk + delete(raw, "chunking") + } + + if limit, ok := raw["limit"].(int); ok { + s.Limit = limit + delete(raw, "limit") + } + + // Everything else goes into Params for strategy-specific configuration + s.Params = raw + + return nil +} + +// MarshalYAML implements custom marshaling to flatten Params into parent level +func (s RAGStrategyConfig) MarshalYAML() (any, error) { + result := s.buildFlattenedMap() + return result, nil +} + +// MarshalJSON implements custom marshaling to flatten Params into parent level +// This ensures JSON and YAML have the same flattened format for consistency +func (s RAGStrategyConfig) MarshalJSON() ([]byte, error) { + result := s.buildFlattenedMap() + return json.Marshal(result) +} + +// UnmarshalJSON implements custom unmarshaling to capture all extra fields into Params +// This ensures JSON and YAML have the same flattened format for consistency +func (s *RAGStrategyConfig) UnmarshalJSON(data []byte) error { + // First unmarshal into a map to capture everything + var raw map[string]any + if err := json.Unmarshal(data, &raw); err != nil { + return err + } + + // Extract known fields + if t, ok := raw["type"].(string); ok { + s.Type = t + delete(raw, "type") + } + + if docs, ok := raw["docs"].([]any); ok { + s.Docs = make([]string, len(docs)) + for i, d := range docs { + if str, ok := d.(string); ok { + s.Docs[i] = str + } + } + delete(raw, "docs") + } + + if dbRaw, ok := raw["database"]; ok { + if dbStr, ok := dbRaw.(string); ok { + var db RAGDatabaseConfig + db.value = dbStr + s.Database = db + } + delete(raw, "database") + } + + if chunkRaw, ok := raw["chunking"]; ok { + // Re-marshal and unmarshal chunking config + chunkBytes, _ := json.Marshal(chunkRaw) + var chunk RAGChunkingConfig + if err := json.Unmarshal(chunkBytes, &chunk); err == nil { + s.Chunking = chunk + } + delete(raw, "chunking") + } + + if limit, ok := raw["limit"].(float64); ok { + s.Limit = int(limit) + delete(raw, "limit") + } + + // Everything else goes into Params for strategy-specific configuration + s.Params = raw + + return nil +} + +// buildFlattenedMap creates a flattened map representation for marshaling +// Used by both MarshalYAML and MarshalJSON to ensure consistent format +func (s RAGStrategyConfig) buildFlattenedMap() map[string]any { + result := make(map[string]any) + + if s.Type != "" { + result["type"] = s.Type + } + if len(s.Docs) > 0 { + result["docs"] = s.Docs + } + if !s.Database.IsEmpty() { + dbStr, _ := s.Database.AsString() + result["database"] = dbStr + } + // Only include chunking if any fields are set + if s.Chunking.Size > 0 || s.Chunking.Overlap > 0 || s.Chunking.RespectWordBoundaries { + result["chunking"] = s.Chunking + } + if s.Limit > 0 { + result["limit"] = s.Limit + } + + // Flatten Params into the same level + maps.Copy(result, s.Params) + + return result +} + +// unmarshalDatabaseConfig handles DatabaseConfig unmarshaling from raw YAML data. +// For RAG strategies, the database configuration is intentionally simple: +// a single string value under the `database` key that points to the SQLite +// database file on disk. TODO(krissetto): eventually support more db types +func unmarshalDatabaseConfig(src any, dst *RAGDatabaseConfig) { + s, ok := src.(string) + if !ok { + return + } + + dst.value = s +} + +// unmarshalChunkingConfig handles ChunkingConfig unmarshaling from raw YAML data +func unmarshalChunkingConfig(src any, dst *RAGChunkingConfig) { + m, ok := src.(map[string]any) + if !ok { + return + } + + // Handle size - try various numeric types that YAML might produce + if size, ok := m["size"]; ok { + dst.Size = coerceToInt(size) + } + + // Handle overlap - try various numeric types that YAML might produce + if overlap, ok := m["overlap"]; ok { + dst.Overlap = coerceToInt(overlap) + } + + // Handle respect_word_boundaries - YAML should give us a bool + if rwb, ok := m["respect_word_boundaries"]; ok { + if val, ok := rwb.(bool); ok { + dst.RespectWordBoundaries = val + } + } + + // Handle code_aware - YAML should give us a bool + if ca, ok := m["code_aware"]; ok { + if val, ok := ca.(bool); ok { + dst.CodeAware = val + } + } +} + +// coerceToInt converts various numeric types to int +func coerceToInt(v any) int { + switch val := v.(type) { + case int: + return val + case int64: + return int(val) + case uint64: + return int(val) + case float64: + return int(val) + default: + return 0 + } +} + +// RAGDatabaseConfig represents database configuration for RAG strategies. +// Currently it only supports a single string value which is interpreted as +// the path to a SQLite database file. +type RAGDatabaseConfig struct { + value any // nil (unset) or string path +} + +// UnmarshalYAML implements custom unmarshaling for DatabaseConfig +func (d *RAGDatabaseConfig) UnmarshalYAML(unmarshal func(any) error) error { + var str string + if err := unmarshal(&str); err == nil { + d.value = str + return nil + } + + return errors.New("database must be a string path to a sqlite database") +} + +// AsString returns the database config as a connection string +// For simple string configs, returns as-is +// For structured configs, builds connection string based on type +func (d *RAGDatabaseConfig) AsString() (string, error) { + if d.value == nil { + return "", nil + } + + if str, ok := d.value.(string); ok { + return str, nil + } + + return "", errors.New("invalid database configuration: expected string path") +} + +// IsEmpty returns true if no database is configured +func (d *RAGDatabaseConfig) IsEmpty() bool { + return d.value == nil +} + +// RAGChunkingConfig represents text chunking configuration +type RAGChunkingConfig struct { + Size int `json:"size,omitempty"` + Overlap int `json:"overlap,omitempty"` + RespectWordBoundaries bool `json:"respect_word_boundaries,omitempty"` + // CodeAware enables code-aware chunking for source files. When true, the + // chunking strategy uses tree-sitter for AST-based chunking, producing + // semantically aligned chunks (e.g., whole functions). Falls back to + // plain text chunking for unsupported languages. + CodeAware bool `json:"code_aware,omitempty"` +} + +// UnmarshalYAML implements custom unmarshaling to apply sensible defaults for chunking +func (c *RAGChunkingConfig) UnmarshalYAML(unmarshal func(any) error) error { + // Use a struct with pointer to distinguish "not set" from "explicitly set to false" + var raw struct { + Size int `yaml:"size"` + Overlap int `yaml:"overlap"` + RespectWordBoundaries *bool `yaml:"respect_word_boundaries"` + } + + if err := unmarshal(&raw); err != nil { + return err + } + + c.Size = raw.Size + c.Overlap = raw.Overlap + + // Apply default of true for RespectWordBoundaries if not explicitly set + if raw.RespectWordBoundaries != nil { + c.RespectWordBoundaries = *raw.RespectWordBoundaries + } else { + c.RespectWordBoundaries = true + } + + return nil +} + +// RAGResultsConfig represents result post-processing configuration (common across strategies) +type RAGResultsConfig struct { + Limit int `json:"limit,omitempty"` // Maximum number of results to return (top K) + Fusion *RAGFusionConfig `json:"fusion,omitempty"` // How to combine results from multiple strategies + Reranking *RAGRerankingConfig `json:"reranking,omitempty"` // Optional reranking configuration + Deduplicate bool `json:"deduplicate,omitempty"` // Remove duplicate documents across strategies + IncludeScore bool `json:"include_score,omitempty"` // Include relevance scores in results + ReturnFullContent bool `json:"return_full_content,omitempty"` // Return full document content instead of just matched chunks +} + +// RAGRerankingConfig represents reranking configuration +type RAGRerankingConfig struct { + Model string `json:"model"` // Model reference for reranking (e.g., "hf.co/ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF") + TopK int `json:"top_k,omitempty"` // Optional: only rerank top K results (0 = rerank all) + Threshold float64 `json:"threshold,omitempty"` // Optional: minimum score threshold after reranking (default: 0.5) + Criteria string `json:"criteria,omitempty"` // Optional: domain-specific relevance criteria to guide scoring +} + +// UnmarshalYAML implements custom unmarshaling to apply sensible defaults for reranking +func (r *RAGRerankingConfig) UnmarshalYAML(unmarshal func(any) error) error { + // Use a struct with pointer to distinguish "not set" from "explicitly set to 0" + var raw struct { + Model string `yaml:"model"` + TopK int `yaml:"top_k"` + Threshold *float64 `yaml:"threshold"` + Criteria string `yaml:"criteria"` + } + + if err := unmarshal(&raw); err != nil { + return err + } + + r.Model = raw.Model + r.TopK = raw.TopK + r.Criteria = raw.Criteria + + // Apply default threshold of 0.5 if not explicitly set + // This filters documents with negative logits (sigmoid < 0.5 = not relevant) + if raw.Threshold != nil { + r.Threshold = *raw.Threshold + } else { + r.Threshold = 0.5 + } + + return nil +} + +// defaultRAGResultsConfig returns the default results configuration +func defaultRAGResultsConfig() RAGResultsConfig { + return RAGResultsConfig{ + Limit: 15, + Deduplicate: true, + IncludeScore: false, + ReturnFullContent: false, + } +} + +// UnmarshalYAML implements custom unmarshaling so we can apply sensible defaults +func (r *RAGResultsConfig) UnmarshalYAML(unmarshal func(any) error) error { + var raw struct { + Limit int `json:"limit,omitempty"` + Fusion *RAGFusionConfig `json:"fusion,omitempty"` + Reranking *RAGRerankingConfig `json:"reranking,omitempty"` + Deduplicate *bool `json:"deduplicate,omitempty"` + IncludeScore *bool `json:"include_score,omitempty"` + ReturnFullContent *bool `json:"return_full_content,omitempty"` + } + + if err := unmarshal(&raw); err != nil { + return err + } + + // Start from defaults and then overwrite with any provided values. + def := defaultRAGResultsConfig() + *r = def + + if raw.Limit != 0 { + r.Limit = raw.Limit + } + r.Fusion = raw.Fusion + r.Reranking = raw.Reranking + + if raw.Deduplicate != nil { + r.Deduplicate = *raw.Deduplicate + } + if raw.IncludeScore != nil { + r.IncludeScore = *raw.IncludeScore + } + if raw.ReturnFullContent != nil { + r.ReturnFullContent = *raw.ReturnFullContent + } + + return nil +} + +// UnmarshalYAML for RAGConfig ensures that the Results field is always +// initialized with defaults, even when the `results` block is omitted. +func (c *RAGConfig) UnmarshalYAML(unmarshal func(any) error) error { + type alias RAGConfig + tmp := alias{ + Results: defaultRAGResultsConfig(), + } + if err := unmarshal(&tmp); err != nil { + return err + } + *c = RAGConfig(tmp) + return nil +} + +// RAGFusionConfig represents configuration for combining multi-strategy results +type RAGFusionConfig struct { + Strategy string `json:"strategy,omitempty"` // Fusion strategy: "rrf" (Reciprocal Rank Fusion), "weighted", "max" + K int `json:"k,omitempty"` // RRF parameter k (default: 60) + Weights map[string]float64 `json:"weights,omitempty"` // Strategy weights for weighted fusion +} + +// PermissionsConfig represents tool permission configuration. +// Allow/Ask/Deny model. This controls tool call approval behavior: +// - Allow: Tools matching these patterns are auto-approved (like --yolo for specific tools) +// - Ask: Tools matching these patterns always require user approval, even if the tool is read-only +// - Deny: Tools matching these patterns are always rejected, even with --yolo +// +// Patterns support glob-style matching (e.g., "shell", "read_*", "mcp:github:*") +// The evaluation order is: Deny (checked first), then Allow, then Ask (explicit), then default +// (read-only tools auto-approved, others ask) +type PermissionsConfig struct { + // Allow lists tool name patterns that are auto-approved without user confirmation + Allow []string `json:"allow,omitempty"` + // Ask lists tool name patterns that always require user confirmation, + // even for tools that are normally auto-approved (e.g. read-only tools) + Ask []string `json:"ask,omitempty"` + // Deny lists tool name patterns that are always rejected + Deny []string `json:"deny,omitempty"` +} + +// HooksConfig represents the hooks configuration for an agent. +// Hooks allow running shell commands at various points in the agent lifecycle. +type HooksConfig struct { + // PreToolUse hooks run before tool execution + PreToolUse []HookMatcherConfig `json:"pre_tool_use,omitempty" yaml:"pre_tool_use,omitempty"` + + // PostToolUse hooks run after tool execution + PostToolUse []HookMatcherConfig `json:"post_tool_use,omitempty" yaml:"post_tool_use,omitempty"` + + // SessionStart hooks run when a session begins + SessionStart []HookDefinition `json:"session_start,omitempty" yaml:"session_start,omitempty"` + + // SessionEnd hooks run when a session ends + SessionEnd []HookDefinition `json:"session_end,omitempty" yaml:"session_end,omitempty"` + + // OnUserInput hooks run when the agent needs user input + OnUserInput []HookDefinition `json:"on_user_input,omitempty" yaml:"on_user_input,omitempty"` + + // Stop hooks run when the model finishes responding and is about to hand control back to the user + Stop []HookDefinition `json:"stop,omitempty" yaml:"stop,omitempty"` + + // Notification hooks run when the agent sends a notification (error, warning) to the user + Notification []HookDefinition `json:"notification,omitempty" yaml:"notification,omitempty"` +} + +// IsEmpty returns true if no hooks are configured +func (h *HooksConfig) IsEmpty() bool { + if h == nil { + return true + } + return len(h.PreToolUse) == 0 && + len(h.PostToolUse) == 0 && + len(h.SessionStart) == 0 && + len(h.SessionEnd) == 0 && + len(h.OnUserInput) == 0 && + len(h.Stop) == 0 && + len(h.Notification) == 0 +} + +// HookMatcherConfig represents a hook matcher with its hooks. +// Used for tool-related hooks (PreToolUse, PostToolUse). +type HookMatcherConfig struct { + // Matcher is a regex pattern to match tool names (e.g., "shell|edit_file") + // Use "*" to match all tools. Case-sensitive. + Matcher string `json:"matcher,omitempty" yaml:"matcher,omitempty"` + + // Hooks are the hooks to execute when the matcher matches + Hooks []HookDefinition `json:"hooks" yaml:"hooks"` +} + +// HookDefinition represents a single hook configuration +type HookDefinition struct { + // Type specifies the hook type (currently only "command" is supported) + Type string `json:"type" yaml:"type"` + + // Command is the shell command to execute + Command string `json:"command,omitempty" yaml:"command,omitempty"` + + // Timeout is the execution timeout in seconds (default: 60) + Timeout int `json:"timeout,omitempty" yaml:"timeout,omitempty"` +} + +// validate validates the HooksConfig +func (h *HooksConfig) validate() error { + // Validate PreToolUse matchers + for i, m := range h.PreToolUse { + if err := m.validate("pre_tool_use", i); err != nil { + return err + } + } + + // Validate PostToolUse matchers + for i, m := range h.PostToolUse { + if err := m.validate("post_tool_use", i); err != nil { + return err + } + } + + // Validate SessionStart hooks + for i, hook := range h.SessionStart { + if err := hook.validate("session_start", i); err != nil { + return err + } + } + + // Validate SessionEnd hooks + for i, hook := range h.SessionEnd { + if err := hook.validate("session_end", i); err != nil { + return err + } + } + + // Validate OnUserInput hooks + for i, hook := range h.OnUserInput { + if err := hook.validate("on_user_input", i); err != nil { + return err + } + } + + // Validate Stop hooks + for i, hook := range h.Stop { + if err := hook.validate("stop", i); err != nil { + return err + } + } + + // Validate Notification hooks + for i, hook := range h.Notification { + if err := hook.validate("notification", i); err != nil { + return err + } + } + + return nil +} + +// validate validates a HookMatcherConfig +func (m *HookMatcherConfig) validate(eventType string, index int) error { + if len(m.Hooks) == 0 { + return fmt.Errorf("hooks.%s[%d]: at least one hook is required", eventType, index) + } + + for i, hook := range m.Hooks { + if err := hook.validate(fmt.Sprintf("%s[%d].hooks", eventType, index), i); err != nil { + return err + } + } + + return nil +} + +// validate validates a HookDefinition +func (h *HookDefinition) validate(prefix string, index int) error { + if h.Type == "" { + return fmt.Errorf("hooks.%s[%d]: type is required", prefix, index) + } + + if h.Type != "command" { + return fmt.Errorf("hooks.%s[%d]: unsupported hook type '%s' (only 'command' is supported)", prefix, index, h.Type) + } + + if h.Command == "" { + return fmt.Errorf("hooks.%s[%d]: command is required for command hooks", prefix, index) + } + + return nil +} diff --git a/pkg/config/v7/types_test.go b/pkg/config/v7/types_test.go new file mode 100644 index 000000000..2dbf67edb --- /dev/null +++ b/pkg/config/v7/types_test.go @@ -0,0 +1,419 @@ +package latest + +import ( + "encoding/json" + "testing" + + "github.com/goccy/go-yaml" + "github.com/stretchr/testify/require" + + "github.com/docker/docker-agent/pkg/config/types" +) + +func TestCommandsUnmarshal_Map(t *testing.T) { + var c types.Commands + input := []byte(` +df: "check disk" +ls: "list files" +`) + err := yaml.Unmarshal(input, &c) + require.NoError(t, err) + require.Equal(t, "check disk", c["df"].Instruction) + require.Equal(t, "list files", c["ls"].Instruction) +} + +func TestCommandsUnmarshal_List(t *testing.T) { + var c types.Commands + input := []byte(` +- df: "check disk" +- ls: "list files" +`) + err := yaml.Unmarshal(input, &c) + require.NoError(t, err) + require.Equal(t, "check disk", c["df"].Instruction) + require.Equal(t, "list files", c["ls"].Instruction) +} + +func TestThinkingBudget_MarshalUnmarshal_String(t *testing.T) { + t.Parallel() + + // Test string effort level + input := []byte(`thinking_budget: minimal`) + var config struct { + ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"` + } + + // Unmarshal + err := yaml.Unmarshal(input, &config) + require.NoError(t, err) + require.NotNil(t, config.ThinkingBudget) + require.Equal(t, "minimal", config.ThinkingBudget.Effort) + require.Equal(t, 0, config.ThinkingBudget.Tokens) + + // Marshal back + output, err := yaml.Marshal(config) + require.NoError(t, err) + require.Equal(t, "thinking_budget: minimal\n", string(output)) +} + +func TestThinkingBudget_MarshalUnmarshal_Integer(t *testing.T) { + t.Parallel() + + // Test integer token budget + input := []byte(`thinking_budget: 8192`) + var config struct { + ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"` + } + + // Unmarshal + err := yaml.Unmarshal(input, &config) + require.NoError(t, err) + require.NotNil(t, config.ThinkingBudget) + require.Empty(t, config.ThinkingBudget.Effort) + require.Equal(t, 8192, config.ThinkingBudget.Tokens) + + // Marshal back + output, err := yaml.Marshal(config) + require.NoError(t, err) + require.Equal(t, "thinking_budget: 8192\n", string(output)) +} + +func TestThinkingBudget_MarshalUnmarshal_NegativeInteger(t *testing.T) { + t.Parallel() + + // Test negative integer token budget (e.g., -1 for Gemini dynamic thinking) + input := []byte(`thinking_budget: -1`) + var config struct { + ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"` + } + + // Unmarshal + err := yaml.Unmarshal(input, &config) + require.NoError(t, err) + require.NotNil(t, config.ThinkingBudget) + require.Empty(t, config.ThinkingBudget.Effort) + require.Equal(t, -1, config.ThinkingBudget.Tokens) + + // Marshal back + output, err := yaml.Marshal(config) + require.NoError(t, err) + require.Equal(t, "thinking_budget: -1\n", string(output)) +} + +func TestThinkingBudget_MarshalUnmarshal_Zero(t *testing.T) { + t.Parallel() + + // Test zero token budget (e.g., 0 for Gemini no thinking) + input := []byte(`thinking_budget: 0`) + var config struct { + ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"` + } + + // Unmarshal + err := yaml.Unmarshal(input, &config) + require.NoError(t, err) + require.NotNil(t, config.ThinkingBudget) + require.Empty(t, config.ThinkingBudget.Effort) + require.Equal(t, 0, config.ThinkingBudget.Tokens) + + // Marshal back + output, err := yaml.Marshal(config) + require.NoError(t, err) + require.Equal(t, "thinking_budget: 0\n", string(output)) +} + +func TestThinkingBudget_IsDisabled(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + b *ThinkingBudget + want bool + }{ + {"nil", nil, false}, + {"zero tokens", &ThinkingBudget{Tokens: 0}, true}, + {"none effort", &ThinkingBudget{Effort: "none"}, true}, + {"positive tokens", &ThinkingBudget{Tokens: 8192}, false}, + {"medium effort", &ThinkingBudget{Effort: "medium"}, false}, + {"adaptive effort", &ThinkingBudget{Effort: "adaptive"}, false}, + {"negative tokens (dynamic)", &ThinkingBudget{Tokens: -1}, false}, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, tt.b.IsDisabled()) + }) + } +} + +func TestThinkingBudget_UnmarshalYAML_InvalidEffort(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + input string + }{ + {"typo", `thinking_budget: adaptative`}, + {"invalid adaptive effort", `thinking_budget: adaptive/ultra`}, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var config struct { + ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"` + } + err := yaml.Unmarshal([]byte(tt.input), &config) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid thinking_budget effort") + }) + } +} + +func TestThinkingBudget_UnmarshalYAML_AdaptiveWithEffort(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + input string + wantEffort string + }{ + {"adaptive", `thinking_budget: adaptive`, "adaptive"}, + {"adaptive/low", `thinking_budget: adaptive/low`, "adaptive/low"}, + {"adaptive/medium", `thinking_budget: adaptive/medium`, "adaptive/medium"}, + {"adaptive/high", `thinking_budget: adaptive/high`, "adaptive/high"}, + {"adaptive/max", `thinking_budget: adaptive/max`, "adaptive/max"}, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var config struct { + ThinkingBudget *ThinkingBudget `yaml:"thinking_budget"` + } + err := yaml.Unmarshal([]byte(tt.input), &config) + require.NoError(t, err) + require.NotNil(t, config.ThinkingBudget) + require.Equal(t, tt.wantEffort, config.ThinkingBudget.Effort) + require.True(t, config.ThinkingBudget.IsAdaptive()) + }) + } +} + +func TestThinkingBudget_UnmarshalJSON_InvalidEffort(t *testing.T) { + t.Parallel() + + data := []byte(`{"thinking_budget": "adaptative"}`) + var config struct { + ThinkingBudget *ThinkingBudget `json:"thinking_budget"` + } + + err := json.Unmarshal(data, &config) + require.Error(t, err) + require.Contains(t, err.Error(), `invalid thinking_budget effort "adaptative"`) +} + +func TestThinkingBudget_IsAdaptive(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + b *ThinkingBudget + want bool + }{ + {"nil", nil, false}, + {"adaptive", &ThinkingBudget{Effort: "adaptive"}, true}, + {"adaptive/high", &ThinkingBudget{Effort: "adaptive/high"}, true}, + {"adaptive/low", &ThinkingBudget{Effort: "adaptive/low"}, true}, + {"adaptive/medium", &ThinkingBudget{Effort: "adaptive/medium"}, true}, + {"adaptive/max", &ThinkingBudget{Effort: "adaptive/max"}, true}, + {"medium", &ThinkingBudget{Effort: "medium"}, false}, + {"tokens", &ThinkingBudget{Tokens: 8192}, false}, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, tt.b.IsAdaptive()) + }) + } +} + +func TestThinkingBudget_AdaptiveEffort(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + b *ThinkingBudget + wantEffort string + wantOK bool + }{ + {"nil", nil, "", false}, + {"adaptive defaults to high", &ThinkingBudget{Effort: "adaptive"}, "high", true}, + {"adaptive/low", &ThinkingBudget{Effort: "adaptive/low"}, "low", true}, + {"adaptive/medium", &ThinkingBudget{Effort: "adaptive/medium"}, "medium", true}, + {"adaptive/high", &ThinkingBudget{Effort: "adaptive/high"}, "high", true}, + {"adaptive/max", &ThinkingBudget{Effort: "adaptive/max"}, "max", true}, + {"not adaptive", &ThinkingBudget{Effort: "medium"}, "", false}, + {"tokens", &ThinkingBudget{Tokens: 8192}, "", false}, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + effort, ok := tt.b.AdaptiveEffort() + require.Equal(t, tt.wantOK, ok) + require.Equal(t, tt.wantEffort, effort) + }) + } +} + +func TestThinkingBudget_EffortTokens(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + b *ThinkingBudget + wantTokens int + wantOK bool + }{ + {"nil", nil, 0, false}, + {"minimal", &ThinkingBudget{Effort: "minimal"}, 1024, true}, + {"low", &ThinkingBudget{Effort: "low"}, 2048, true}, + {"medium", &ThinkingBudget{Effort: "medium"}, 8192, true}, + {"high", &ThinkingBudget{Effort: "high"}, 16384, true}, + {"adaptive", &ThinkingBudget{Effort: "adaptive"}, 0, false}, + {"none", &ThinkingBudget{Effort: "none"}, 0, false}, + {"explicit tokens", &ThinkingBudget{Tokens: 4096}, 0, false}, + {"empty effort", &ThinkingBudget{}, 0, false}, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tokens, ok := tt.b.EffortTokens() + require.Equal(t, tt.wantOK, ok) + require.Equal(t, tt.wantTokens, tokens) + }) + } +} + +func TestAgents_UnmarshalYAML_RejectsUnknownFields(t *testing.T) { + t.Parallel() + + // "instructions" (plural) is not a valid field; the correct field is "instruction" (singular). + // Agents.UnmarshalYAML must reject it so that typos don't silently drop config. + input := []byte(`version: "5" +agents: + root: + model: openai/gpt-4o + instructions: "You are a helpful assistant." +`) + + _, err := parse(input) + require.Error(t, err) + require.Contains(t, err.Error(), "instructions") +} + +func TestAgents_UnmarshalYAML_AcceptsValidConfig(t *testing.T) { + t.Parallel() + + input := []byte(`version: "5" +agents: + root: + model: openai/gpt-4o + instruction: "You are a helpful assistant." +`) + + cfg, err := parse(input) + require.NoError(t, err) + require.Len(t, cfg.Agents, 1) + require.Equal(t, "root", cfg.Agents[0].Name) + require.Equal(t, "You are a helpful assistant.", cfg.Agents[0].Instruction) +} + +func TestRAGStrategyConfig_MarshalUnmarshal_FlattenedParams(t *testing.T) { + t.Parallel() + + // Test that params are flattened during unmarshal and remain flattened after marshal + input := []byte(`type: chunked-embeddings +model: embeddinggemma +database: ./rag/test.db +threshold: 0.5 +vector_dimensions: 768 +`) + + var strategy RAGStrategyConfig + + // Unmarshal + err := yaml.Unmarshal(input, &strategy) + require.NoError(t, err) + require.Equal(t, "chunked-embeddings", strategy.Type) + require.Equal(t, "./rag/test.db", mustGetDBString(t, strategy.Database)) + require.NotNil(t, strategy.Params) + require.Equal(t, "embeddinggemma", strategy.Params["model"]) + require.InEpsilon(t, 0.5, strategy.Params["threshold"], 0.001) + // YAML may unmarshal numbers as different numeric types (int, uint64, float64) + require.InEpsilon(t, float64(768), toFloat64(strategy.Params["vector_dimensions"]), 0.001) + + // Marshal back + output, err := yaml.Marshal(strategy) + require.NoError(t, err) + + // Verify it's still flattened (no "params:" key) + outputStr := string(output) + require.Contains(t, outputStr, "type: chunked-embeddings") + require.Contains(t, outputStr, "model: embeddinggemma") + require.Contains(t, outputStr, "threshold: 0.5") + require.Contains(t, outputStr, "vector_dimensions: 768") + require.NotContains(t, outputStr, "params:") + + // Unmarshal again to verify round-trip + var strategy2 RAGStrategyConfig + err = yaml.Unmarshal(output, &strategy2) + require.NoError(t, err) + require.Equal(t, strategy.Type, strategy2.Type) + require.Equal(t, strategy.Params["model"], strategy2.Params["model"]) + require.Equal(t, strategy.Params["threshold"], strategy2.Params["threshold"]) + // YAML may unmarshal numbers as different numeric types (int, uint64, float64) + // Just verify the numeric value is correct + require.InEpsilon(t, float64(768), toFloat64(strategy2.Params["vector_dimensions"]), 0.001) +} + +func TestRAGStrategyConfig_MarshalUnmarshal_WithDatabase(t *testing.T) { + t.Parallel() + + input := []byte(`type: chunked-embeddings +database: ./test.db +model: test-model +`) + + var strategy RAGStrategyConfig + err := yaml.Unmarshal(input, &strategy) + require.NoError(t, err) + + // Marshal back + output, err := yaml.Marshal(strategy) + require.NoError(t, err) + + // Should contain database as a simple string, not nested with sub-fields + outputStr := string(output) + require.Contains(t, outputStr, "database: ./test.db") + require.NotContains(t, outputStr, " value:") // Should not be nested with internal fields + require.Contains(t, outputStr, "model: test-model") + require.NotContains(t, outputStr, "params:") // Should be flattened +} + +func mustGetDBString(t *testing.T, db RAGDatabaseConfig) string { + t.Helper() + str, err := db.AsString() + require.NoError(t, err) + return str +} + +// toFloat64 converts various numeric types to float64 for comparison +func toFloat64(v any) float64 { + switch val := v.(type) { + case int: + return float64(val) + case int64: + return float64(val) + case uint64: + return float64(val) + case float64: + return val + case float32: + return float64(val) + default: + return 0 + } +} diff --git a/pkg/config/v7/validate.go b/pkg/config/v7/validate.go new file mode 100644 index 000000000..6d2131418 --- /dev/null +++ b/pkg/config/v7/validate.go @@ -0,0 +1,168 @@ +package latest + +import ( + "errors" +) + +func (t *Config) UnmarshalYAML(unmarshal func(any) error) error { + type alias Config + var tmp alias + if err := unmarshal(&tmp); err != nil { + return err + } + *t = Config(tmp) + return t.validate() +} + +func (t *Config) validate() error { + for i := range t.Agents { + agent := &t.Agents[i] + + // Validate fallback config + if err := agent.validateFallback(); err != nil { + return err + } + + for j := range agent.Toolsets { + if err := agent.Toolsets[j].validate(); err != nil { + return err + } + } + if agent.Hooks != nil { + if err := agent.Hooks.validate(); err != nil { + return err + } + } + } + + return nil +} + +// validateFallback validates the fallback configuration for an agent +func (a *AgentConfig) validateFallback() error { + if a.Fallback == nil { + return nil + } + + // -1 is allowed as a special value meaning "explicitly no retries" + if a.Fallback.Retries < -1 { + return errors.New("fallback.retries must be >= -1 (use -1 for no retries, 0 for default)") + } + if a.Fallback.Cooldown.Duration < 0 { + return errors.New("fallback.cooldown must be non-negative") + } + + return nil +} + +func (t *Toolset) validate() error { + // Attributes used on the wrong toolset type. + if len(t.Shell) > 0 && t.Type != "script" { + return errors.New("shell can only be used with type 'script'") + } + if t.Path != "" && t.Type != "memory" && t.Type != "tasks" { + return errors.New("path can only be used with type 'memory' or 'tasks'") + } + if len(t.PostEdit) > 0 && t.Type != "filesystem" { + return errors.New("post_edit can only be used with type 'filesystem'") + } + if t.IgnoreVCS != nil && t.Type != "filesystem" { + return errors.New("ignore_vcs can only be used with type 'filesystem'") + } + if len(t.Env) > 0 && (t.Type != "shell" && t.Type != "script" && t.Type != "mcp" && t.Type != "lsp") { + return errors.New("env can only be used with type 'shell', 'script', 'mcp' or 'lsp'") + } + if len(t.FileTypes) > 0 && t.Type != "lsp" { + return errors.New("file_types can only be used with type 'lsp'") + } + if len(t.Models) > 0 && t.Type != "model_picker" { + return errors.New("models can only be used with type 'model_picker'") + } + if t.Shared && t.Type != "todo" { + return errors.New("shared can only be used with type 'todo'") + } + if t.Version != "" && t.Type != "mcp" && t.Type != "lsp" { + return errors.New("version can only be used with type 'mcp' or 'lsp'") + } + if t.Command != "" && t.Type != "mcp" && t.Type != "lsp" { + return errors.New("command can only be used with type 'mcp' or 'lsp'") + } + if len(t.Args) > 0 && t.Type != "mcp" && t.Type != "lsp" { + return errors.New("args can only be used with type 'mcp' or 'lsp'") + } + if t.Ref != "" && t.Type != "mcp" && t.Type != "rag" { + return errors.New("ref can only be used with type 'mcp' or 'rag'") + } + if (t.Remote.URL != "" || t.Remote.TransportType != "") && t.Type != "mcp" { + return errors.New("remote can only be used with type 'mcp'") + } + if (len(t.Remote.Headers) > 0) && (t.Type != "mcp" && t.Type != "a2a") { + return errors.New("remote headers can only be used with type 'mcp' or 'a2a'") + } + if len(t.Headers) > 0 && t.Type != "openapi" && t.Type != "a2a" { + return errors.New("headers can only be used with type 'openapi' or 'a2a'") + } + if t.Config != nil && t.Type != "mcp" { + return errors.New("config can only be used with type 'mcp'") + } + if t.URL != "" && t.Type != "a2a" && t.Type != "openapi" { + return errors.New("url can only be used with type 'a2a' or 'openapi'") + } + if t.Name != "" && (t.Type != "mcp" && t.Type != "a2a") { + return errors.New("name can only be used with type 'mcp' or 'a2a'") + } + if t.RAGConfig != nil && t.Type != "rag" { + return errors.New("rag_config can only be used with type 'rag'") + } + + switch t.Type { + case "shell": + // no additional validation needed + case "memory": + // path is optional; defaults to ~/.cagent/memory//memory.db + case "tasks": + // path defaults to ./tasks.json if not set + case "mcp": + count := 0 + if t.Command != "" { + count++ + } + if t.Remote.URL != "" { + count++ + } + if t.Ref != "" { + count++ + } + if count == 0 { + return errors.New("either command, remote or ref must be set") + } + if count > 1 { + return errors.New("either command, remote or ref must be set, but only one of those") + } + case "a2a": + if t.URL == "" { + return errors.New("a2a toolset requires a url to be set") + } + case "lsp": + if t.Command == "" { + return errors.New("lsp toolset requires a command to be set") + } + case "openapi": + if t.URL == "" { + return errors.New("openapi toolset requires a url to be set") + } + case "model_picker": + if len(t.Models) == 0 { + return errors.New("model_picker toolset requires at least one model in the 'models' list") + } + case "rag": + // rag toolset requires either a ref or inline rag_config + if t.Ref == "" && t.RAGConfig == nil { + return errors.New("rag toolset requires either ref or rag_config") + } + case "background_agents": + // no additional validation needed + } + + return nil +} diff --git a/pkg/config/v7/validate_test.go b/pkg/config/v7/validate_test.go new file mode 100644 index 000000000..6c90b9be7 --- /dev/null +++ b/pkg/config/v7/validate_test.go @@ -0,0 +1,117 @@ +package latest + +import ( + "testing" + + "github.com/goccy/go-yaml" + "github.com/stretchr/testify/require" +) + +func TestToolset_Validate_LSP(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + config string + wantErr string + }{ + { + name: "valid lsp with command", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: lsp + command: gopls +`, + wantErr: "", + }, + { + name: "lsp missing command", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: lsp +`, + wantErr: "lsp toolset requires a command to be set", + }, + { + name: "lsp with args", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: lsp + command: gopls + args: + - -remote=auto +`, + wantErr: "", + }, + { + name: "lsp with env", + config: ` +version: "3" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: lsp + command: gopls + env: + GOFLAGS: "-mod=vendor" +`, + wantErr: "", + }, + { + name: "lsp with file_types", + config: ` +version: "5" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: lsp + command: gopls + file_types: [".go", ".mod"] +`, + wantErr: "", + }, + { + name: "file_types on non-lsp toolset", + config: ` +version: "5" +agents: + root: + model: "openai/gpt-4" + toolsets: + - type: shell + file_types: [".go"] +`, + wantErr: "file_types can only be used with type 'lsp'", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var cfg Config + err := yaml.Unmarshal([]byte(tt.config), &cfg) + + if tt.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tt.wantErr) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/pkg/config/versions.go b/pkg/config/versions.go index 1bb69f75f..846ef5037 100644 --- a/pkg/config/versions.go +++ b/pkg/config/versions.go @@ -9,6 +9,7 @@ import ( v4 "github.com/docker/docker-agent/pkg/config/v4" v5 "github.com/docker/docker-agent/pkg/config/v5" v6 "github.com/docker/docker-agent/pkg/config/v6" + v7 "github.com/docker/docker-agent/pkg/config/v7" ) func versions() (map[string]func([]byte) (any, error), []func(any, []byte) (any, error)) { @@ -22,6 +23,7 @@ func versions() (map[string]func([]byte) (any, error), []func(any, []byte) (any, v4.Register(parsers, &upgraders) v5.Register(parsers, &upgraders) v6.Register(parsers, &upgraders) + v7.Register(parsers, &upgraders) latest.Register(parsers, &upgraders) return parsers, upgraders