From ab33aab2eec50105194c9440ba701f66c01557e2 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Thu, 12 Mar 2026 11:25:58 +0000 Subject: [PATCH] Add explicit OpenAI API mode selection --- providers/azure/api_mode_test.go | 66 ++++++++ providers/azure/azure.go | 21 ++- providers/openai/api_mode_test.go | 165 +++++++++++++++++++ providers/openai/language_model_hooks.go | 9 +- providers/openai/openai.go | 89 ++++++++-- providers/openai/provider_options_guard.go | 49 ++++++ providers/openai/responses_language_model.go | 32 ++-- providers/openai/responses_options.go | 2 +- 8 files changed, 402 insertions(+), 31 deletions(-) create mode 100644 providers/azure/api_mode_test.go create mode 100644 providers/openai/api_mode_test.go create mode 100644 providers/openai/provider_options_guard.go diff --git a/providers/azure/api_mode_test.go b/providers/azure/api_mode_test.go new file mode 100644 index 000000000..37d83fef7 --- /dev/null +++ b/providers/azure/api_mode_test.go @@ -0,0 +1,66 @@ +package azure + +import ( + "context" + "reflect" + "testing" + + "github.com/stretchr/testify/require" +) + +const ( + allowlistedResponsesModelID = "gpt-4o" + nonAllowlistedModelID = "definitely-not-in-responses-allowlist" +) + +func TestWithAPIModeResponsesForcesResponsesImplementation(t *testing.T) { + t.Parallel() + + provider, err := New(WithAPIMode(APIModeResponses)) + require.NoError(t, err) + + lm, err := provider.LanguageModel(context.Background(), nonAllowlistedModelID) + require.NoError(t, err) + require.Equal(t, "openai.responsesLanguageModel", reflect.TypeOf(lm).String()) +} + +func TestWithAPIModeChatCompletionsForcesChatImplementation(t *testing.T) { + t.Parallel() + + provider, err := New(WithAPIMode(APIModeChatCompletions)) + require.NoError(t, err) + + lm, err := provider.LanguageModel(context.Background(), allowlistedResponsesModelID) + require.NoError(t, err) + require.Equal(t, "openai.languageModel", reflect.TypeOf(lm).String()) +} + +func TestWithUseResponsesAPIMatchesExplicitAutoMode(t *testing.T) { + t.Parallel() + + legacyProvider, err := New(WithUseResponsesAPI()) + require.NoError(t, err) + autoProvider, err := New(WithAPIMode(APIModeAuto)) + require.NoError(t, err) + + for _, modelID := range []string{allowlistedResponsesModelID, nonAllowlistedModelID} { + modelID := modelID + t.Run(modelID, func(t *testing.T) { + t.Parallel() + + legacyLM, err := legacyProvider.LanguageModel(context.Background(), modelID) + require.NoError(t, err) + autoLM, err := autoProvider.LanguageModel(context.Background(), modelID) + require.NoError(t, err) + + require.Equal(t, reflect.TypeOf(legacyLM), reflect.TypeOf(autoLM)) + }) + } +} + +func TestWithAPIModeRejectsUnknownValue(t *testing.T) { + t.Parallel() + + _, err := New(WithAPIMode(APIMode("definitely-invalid"))) + require.EqualError(t, err, `invalid OpenAI API mode "definitely-invalid"`) +} diff --git a/providers/azure/azure.go b/providers/azure/azure.go index a68df7251..f55ab7189 100644 --- a/providers/azure/azure.go +++ b/providers/azure/azure.go @@ -27,6 +27,18 @@ const ( defaultAPIVersion = "2025-01-01-preview" ) +// APIMode selects which OpenAI-compatible API Fantasy should target for Azure. +type APIMode = openai.APIMode + +const ( + // APIModeAuto uses Fantasy's Responses API allowlist heuristic. + APIModeAuto = openai.APIModeAuto + // APIModeChatCompletions forces the Chat Completions implementation. + APIModeChatCompletions = openai.APIModeChatCompletions + // APIModeResponses forces the Responses implementation. + APIModeResponses = openai.APIModeResponses +) + // azureURLPattern matches Azure OpenAI endpoint URLs in various formats: // * https://resource-id.openai.azure.com; // * https://resource-id.openai.azure.com/; @@ -117,7 +129,14 @@ func WithUserAgent(ua string) Option { } } -// WithUseResponsesAPI configures the provider to use the responses API for models that support it. +// WithAPIMode deterministically selects the Azure OpenAI API mode. +func WithAPIMode(mode APIMode) Option { + return func(o *options) { + o.openaiOptions = append(o.openaiOptions, openai.WithAPIMode(mode)) + } +} + +// WithUseResponsesAPI configures the provider to use the responses API heuristic for models that support it. func WithUseResponsesAPI() Option { return func(o *options) { o.openaiOptions = append(o.openaiOptions, openai.WithUseResponsesAPI()) diff --git a/providers/openai/api_mode_test.go b/providers/openai/api_mode_test.go new file mode 100644 index 000000000..b68c7f020 --- /dev/null +++ b/providers/openai/api_mode_test.go @@ -0,0 +1,165 @@ +package openai + +import ( + "context" + "reflect" + "testing" + + "charm.land/fantasy" + sdkopenai "github.com/openai/openai-go/v2" + "github.com/stretchr/testify/require" +) + +const ( + allowlistedResponsesModelID = "gpt-4o" + nonAllowlistedModelID = "definitely-not-in-responses-allowlist" +) + +func TestLanguageModelDefaultModeUsesChatCompletions(t *testing.T) { + t.Parallel() + + provider, err := New() + require.NoError(t, err) + + lm, err := provider.LanguageModel(context.Background(), allowlistedResponsesModelID) + require.NoError(t, err) + require.IsType(t, languageModel{}, lm) +} + +func TestLanguageModelWithUseResponsesAPIRemainsHeuristicOnly(t *testing.T) { + t.Parallel() + + provider, err := New(WithUseResponsesAPI()) + require.NoError(t, err) + + responsesLM, err := provider.LanguageModel(context.Background(), allowlistedResponsesModelID) + require.NoError(t, err) + require.IsType(t, responsesLanguageModel{}, responsesLM) + + chatLM, err := provider.LanguageModel(context.Background(), nonAllowlistedModelID) + require.NoError(t, err) + require.IsType(t, languageModel{}, chatLM) +} + +func TestLanguageModelWithAPIModeAutoMatchesLegacyHeuristic(t *testing.T) { + t.Parallel() + + legacyProvider, err := New(WithUseResponsesAPI()) + require.NoError(t, err) + autoProvider, err := New(WithAPIMode(APIModeAuto)) + require.NoError(t, err) + + for _, modelID := range []string{allowlistedResponsesModelID, nonAllowlistedModelID} { + modelID := modelID + t.Run(modelID, func(t *testing.T) { + t.Parallel() + + legacyLM, err := legacyProvider.LanguageModel(context.Background(), modelID) + require.NoError(t, err) + autoLM, err := autoProvider.LanguageModel(context.Background(), modelID) + require.NoError(t, err) + + require.Equal(t, reflect.TypeOf(legacyLM), reflect.TypeOf(autoLM)) + }) + } +} + +func TestLanguageModelWithAPIModeResponsesBypassesAllowlist(t *testing.T) { + t.Parallel() + + provider, err := New(WithAPIMode(APIModeResponses)) + require.NoError(t, err) + + lm, err := provider.LanguageModel(context.Background(), nonAllowlistedModelID) + require.NoError(t, err) + require.IsType(t, responsesLanguageModel{}, lm) +} + +func TestLanguageModelWithAPIModeChatCompletionsBypassesAllowlist(t *testing.T) { + t.Parallel() + + provider, err := New(WithAPIMode(APIModeChatCompletions)) + require.NoError(t, err) + + lm, err := provider.LanguageModel(context.Background(), allowlistedResponsesModelID) + require.NoError(t, err) + require.IsType(t, languageModel{}, lm) +} + +func TestWithAPIModeRejectsUnknownValue(t *testing.T) { + t.Parallel() + + _, err := New(WithAPIMode(APIMode("definitely-invalid"))) + require.EqualError(t, err, `invalid OpenAI API mode "definitely-invalid"`) +} + +func TestChatProviderOptionsAcceptsChatOptions(t *testing.T) { + t.Parallel() + + user := "test-user" + params := &sdkopenai.ChatCompletionNewParams{} + warnings, err := DefaultPrepareCallFunc( + newLanguageModel("gpt-4o", Name, sdkopenai.Client{}), + params, + fantasy.Call{ProviderOptions: NewProviderOptions(&ProviderOptions{User: &user})}, + ) + require.NoError(t, err) + require.Empty(t, warnings) + require.True(t, params.User.Valid()) + require.Equal(t, user, params.User.Value) +} + +func TestChatProviderOptionsRejectsResponsesOptions(t *testing.T) { + t.Parallel() + + _, err := DefaultPrepareCallFunc( + newLanguageModel("gpt-4o", Name, sdkopenai.Client{}), + &sdkopenai.ChatCompletionNewParams{}, + fantasy.Call{ProviderOptions: NewResponsesProviderOptions(&ResponsesProviderOptions{})}, + ) + require.EqualError(t, err, "invalid argument: openai chat_completions API mode expects provider options *openai.ProviderOptions, got *openai.ResponsesProviderOptions") +} + +func TestResponsesProviderOptionsAcceptsResponsesOptions(t *testing.T) { + t.Parallel() + + user := "test-user" + model := newResponsesLanguageModel("gpt-4o", Name, sdkopenai.Client{}, fantasy.ObjectModeAuto) + params, warnings, err := model.prepareParams(fantasy.Call{ + ProviderOptions: NewResponsesProviderOptions(&ResponsesProviderOptions{User: &user}), + }) + require.NoError(t, err) + require.Empty(t, warnings) + require.True(t, params.User.Valid()) + require.Equal(t, user, params.User.Value) +} + +func TestResponsesProviderOptionsRejectsChatOptions(t *testing.T) { + t.Parallel() + + model := newResponsesLanguageModel("gpt-4o", Name, sdkopenai.Client{}, fantasy.ObjectModeAuto) + _, _, err := model.prepareParams(fantasy.Call{ + ProviderOptions: NewProviderOptions(&ProviderOptions{}), + }) + require.EqualError(t, err, "invalid argument: openai responses API mode expects provider options *openai.ResponsesProviderOptions, got *openai.ProviderOptions") +} + +func TestChatProviderOptionsAllowsNoProviderOptions(t *testing.T) { + t.Parallel() + + params := &sdkopenai.ChatCompletionNewParams{} + warnings, err := DefaultPrepareCallFunc(newLanguageModel("gpt-4o", Name, sdkopenai.Client{}), params, fantasy.Call{}) + require.NoError(t, err) + require.Empty(t, warnings) + require.False(t, params.User.Valid()) +} + +func TestResponsesProviderOptionsAllowsNoProviderOptions(t *testing.T) { + t.Parallel() + + model := newResponsesLanguageModel("gpt-4o", Name, sdkopenai.Client{}, fantasy.ObjectModeAuto) + params, warnings, err := model.prepareParams(fantasy.Call{}) + require.NoError(t, err) + require.Empty(t, warnings) + require.True(t, params.Store.Valid()) +} diff --git a/providers/openai/language_model_hooks.go b/providers/openai/language_model_hooks.go index d58c1794f..6d30f5b0c 100644 --- a/providers/openai/language_model_hooks.go +++ b/providers/openai/language_model_hooks.go @@ -41,12 +41,9 @@ func DefaultPrepareCallFunc(model fantasy.LanguageModel, params *openai.ChatComp return nil, nil } var warnings []fantasy.CallWarning - providerOptions := &ProviderOptions{} - if v, ok := call.ProviderOptions[Name]; ok { - providerOptions, ok = v.(*ProviderOptions) - if !ok { - return nil, &fantasy.Error{Title: "invalid argument", Message: "openai provider options should be *openai.ProviderOptions"} - } + providerOptions, err := chatProviderOptions(call) + if err != nil { + return nil, err } if providerOptions.LogitBias != nil { diff --git a/providers/openai/openai.go b/providers/openai/openai.go index 928f3dd28..d2460224a 100644 --- a/providers/openai/openai.go +++ b/providers/openai/openai.go @@ -4,6 +4,7 @@ package openai import ( "cmp" "context" + "fmt" "maps" "charm.land/fantasy" @@ -19,6 +20,18 @@ const ( DefaultURL = "https://api.openai.com/v1" ) +// APIMode selects which OpenAI API Fantasy should target. +type APIMode string + +const ( + // APIModeAuto uses Fantasy's Responses API allowlist heuristic. + APIModeAuto APIMode = "auto" + // APIModeChatCompletions forces the Chat Completions implementation. + APIModeChatCompletions APIMode = "chat_completions" + // APIModeResponses forces the Responses implementation. + APIModeResponses APIMode = "responses" +) + type provider struct { options options } @@ -29,7 +42,8 @@ type options struct { organization string project string name string - useResponsesAPI bool + apiMode APIMode + apiModeSet bool headers map[string]string userAgent string client option.HTTPClient @@ -53,6 +67,11 @@ func New(opts ...Option) (fantasy.Provider, error) { providerOptions.baseURL = cmp.Or(providerOptions.baseURL, DefaultURL) providerOptions.name = cmp.Or(providerOptions.name, Name) + if providerOptions.apiModeSet { + if err := validateAPIMode(providerOptions.apiMode); err != nil { + return nil, err + } + } if providerOptions.organization != "" { providerOptions.headers["OpenAi-Organization"] = providerOptions.organization @@ -127,13 +146,19 @@ func WithLanguageModelOptions(opts ...LanguageModelOption) Option { } } -// WithUseResponsesAPI configures the provider to use the responses API for models that support it. -func WithUseResponsesAPI() Option { +// WithAPIMode deterministically selects the OpenAI API mode. +func WithAPIMode(mode APIMode) Option { return func(o *options) { - o.useResponsesAPI = true + o.apiMode = mode + o.apiModeSet = true } } +// WithUseResponsesAPI configures the provider to use the responses API heuristic for models that support it. +func WithUseResponsesAPI() Option { + return WithAPIMode(APIModeAuto) +} + // WithUserAgent sets an explicit User-Agent header, overriding the default and any // value set via WithHeaders. func WithUserAgent(ua string) Option { @@ -179,23 +204,63 @@ func (o *provider) LanguageModel(_ context.Context, modelID string) (fantasy.Lan client := openai.NewClient(openaiClientOptions...) - if o.options.useResponsesAPI && IsResponsesModel(modelID) { + apiMode, err := o.options.effectiveAPIMode(modelID) + if err != nil { + return nil, err + } + + switch apiMode { + case APIModeResponses: // Not supported for responses API objectMode := o.options.objectMode if objectMode == fantasy.ObjectModeJSON { objectMode = fantasy.ObjectModeAuto } return newResponsesLanguageModel(modelID, o.options.name, client, objectMode), nil + case APIModeChatCompletions: + languageModelOptions := append([]LanguageModelOption{}, o.options.languageModelOptions...) + languageModelOptions = append(languageModelOptions, WithLanguageModelObjectMode(o.options.objectMode)) + + return newLanguageModel( + modelID, + o.options.name, + client, + languageModelOptions..., + ), nil + default: + return nil, fmt.Errorf("internal error: unhandled OpenAI API mode %q", apiMode) + } +} + +func validateAPIMode(mode APIMode) error { + switch mode { + case APIModeAuto, APIModeChatCompletions, APIModeResponses: + return nil + default: + return fmt.Errorf("invalid OpenAI API mode %q", mode) + } +} + +func (o options) effectiveAPIMode(modelID string) (APIMode, error) { + if !o.apiModeSet { + return APIModeChatCompletions, nil } - o.options.languageModelOptions = append(o.options.languageModelOptions, WithLanguageModelObjectMode(o.options.objectMode)) + if err := validateAPIMode(o.apiMode); err != nil { + return "", err + } - return newLanguageModel( - modelID, - o.options.name, - client, - o.options.languageModelOptions..., - ), nil + switch o.apiMode { + case APIModeAuto: + if IsResponsesModel(modelID) { + return APIModeResponses, nil + } + return APIModeChatCompletions, nil + case APIModeChatCompletions, APIModeResponses: + return o.apiMode, nil + default: + return "", fmt.Errorf("internal error: unhandled configured OpenAI API mode %q", o.apiMode) + } } func (o *provider) Name() string { diff --git a/providers/openai/provider_options_guard.go b/providers/openai/provider_options_guard.go new file mode 100644 index 000000000..6c3be4752 --- /dev/null +++ b/providers/openai/provider_options_guard.go @@ -0,0 +1,49 @@ +package openai + +import ( + "fmt" + + "charm.land/fantasy" +) + +func chatProviderOptions(call fantasy.Call) (*ProviderOptions, error) { + if call.ProviderOptions == nil { + return &ProviderOptions{}, nil + } + + providerOptions := &ProviderOptions{} + if value, ok := call.ProviderOptions[Name]; ok { + typed, ok := value.(*ProviderOptions) + if !ok { + return nil, providerOptionsTypeError(APIModeChatCompletions, "*openai.ProviderOptions", value) + } + providerOptions = typed + } + + return providerOptions, nil +} + +func responsesProviderOptions(call fantasy.Call) (*ResponsesProviderOptions, error) { + if call.ProviderOptions == nil { + return nil, nil + } + + value, ok := call.ProviderOptions[Name] + if !ok { + return nil, nil + } + + typed, ok := value.(*ResponsesProviderOptions) + if !ok { + return nil, providerOptionsTypeError(APIModeResponses, "*openai.ResponsesProviderOptions", value) + } + + return typed, nil +} + +func providerOptionsTypeError(mode APIMode, expectedType string, actual any) error { + return &fantasy.Error{ + Title: "invalid argument", + Message: fmt.Sprintf("openai %s API mode expects provider options %s, got %T", mode, expectedType, actual), + } +} diff --git a/providers/openai/responses_language_model.go b/providers/openai/responses_language_model.go index 090117de0..3cbdb484f 100644 --- a/providers/openai/responses_language_model.go +++ b/providers/openai/responses_language_model.go @@ -119,7 +119,7 @@ func getResponsesModelConfig(modelID string) responsesModelConfig { } } -func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.ResponseNewParams, []fantasy.CallWarning) { +func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.ResponseNewParams, []fantasy.CallWarning, error) { var warnings []fantasy.CallWarning params := &responses.ResponseNewParams{ Store: param.NewOpt(false), @@ -148,11 +148,9 @@ func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.Res }) } - var openaiOptions *ResponsesProviderOptions - if opts, ok := call.ProviderOptions[Name]; ok { - if typedOpts, ok := opts.(*ResponsesProviderOptions); ok { - openaiOptions = typedOpts - } + openaiOptions, err := responsesProviderOptions(call) + if err != nil { + return nil, nil, err } input, inputWarnings := toResponsesPrompt(call.Prompt, modelConfig.systemMessageMode) @@ -326,7 +324,7 @@ func (o responsesLanguageModel) prepareParams(call fantasy.Call) (*responses.Res params.ToolChoice = toolChoice } - return params, warnings + return params, warnings, nil } func toResponsesPrompt(prompt fantasy.Prompt, systemMessageMode string) (responses.ResponseInputParam, []fantasy.CallWarning) { @@ -667,7 +665,10 @@ func toResponsesTools(tools []fantasy.Tool, toolChoice *fantasy.ToolChoice, opti } func (o responsesLanguageModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { - params, warnings := o.prepareParams(call) + params, warnings, err := o.prepareParams(call) + if err != nil { + return nil, err + } response, err := o.client.Responses.New(ctx, *params, callUARequestOptions(call)...) if err != nil { return nil, toProviderErr(err) @@ -804,7 +805,10 @@ func mapResponsesFinishReason(reason string, hasFunctionCall bool) fantasy.Finis } func (o responsesLanguageModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - params, warnings := o.prepareParams(call) + params, warnings, err := o.prepareParams(call) + if err != nil { + return nil, err + } stream := o.client.Responses.NewStreaming(ctx, *params, callUARequestOptions(call)...) @@ -1098,7 +1102,10 @@ func (o responsesLanguageModel) generateObjectWithJSONMode(ctx context.Context, ProviderOptions: call.ProviderOptions, } - params, warnings := o.prepareParams(fantasyCall) + params, warnings, err := o.prepareParams(fantasyCall) + if err != nil { + return nil, err + } // Add structured output via Text.Format field params.Text = responses.ResponseTextConfigParam{ @@ -1209,7 +1216,10 @@ func (o responsesLanguageModel) streamObjectWithJSONMode(ctx context.Context, ca ProviderOptions: call.ProviderOptions, } - params, warnings := o.prepareParams(fantasyCall) + params, warnings, err := o.prepareParams(fantasyCall) + if err != nil { + return nil, err + } // Add structured output via Text.Format field params.Text = responses.ResponseTextConfigParam{ diff --git a/providers/openai/responses_options.go b/providers/openai/responses_options.go index 41ca2f67e..6a559563e 100644 --- a/providers/openai/responses_options.go +++ b/providers/openai/responses_options.go @@ -214,7 +214,7 @@ func ParseResponsesOptions(data map[string]any) (*ResponsesProviderOptions, erro return &options, nil } -// IsResponsesModel checks if a model ID is a Responses API model for OpenAI. +// IsResponsesModel checks if a model ID is on Fantasy's OpenAI Responses API allowlist heuristic. func IsResponsesModel(modelID string) bool { return slices.Contains(responsesModelIDs, modelID) }