Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions providers/azure/api_mode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package azure

import (
"context"
"reflect"
"testing"

"github.com/stretchr/testify/require"
)

// Model IDs used to probe the provider's API-mode selection logic.
const (
	// allowlistedResponsesModelID is expected to be on the openai package's
	// Responses API allowlist (the tests below rely on this).
	allowlistedResponsesModelID = "gpt-4o"
	// nonAllowlistedModelID is an ID that should never appear on the allowlist.
	nonAllowlistedModelID = "definitely-not-in-responses-allowlist"
)

// TestWithAPIModeResponsesForcesResponsesImplementation verifies that forcing
// APIModeResponses selects the Responses implementation even for a model that
// is not on the Responses allowlist.
func TestWithAPIModeResponsesForcesResponsesImplementation(t *testing.T) {
	t.Parallel()

	provider, err := New(WithAPIMode(APIModeResponses))
	require.NoError(t, err)

	model, err := provider.LanguageModel(context.Background(), nonAllowlistedModelID)
	require.NoError(t, err)

	// The concrete type is unexported in the openai package, so compare by
	// its fully qualified type name instead of require.IsType.
	gotType := reflect.TypeOf(model).String()
	require.Equal(t, "openai.responsesLanguageModel", gotType)
}

// TestWithAPIModeChatCompletionsForcesChatImplementation verifies that forcing
// APIModeChatCompletions selects the Chat Completions implementation even for
// a model that is on the Responses allowlist.
func TestWithAPIModeChatCompletionsForcesChatImplementation(t *testing.T) {
	t.Parallel()

	provider, err := New(WithAPIMode(APIModeChatCompletions))
	require.NoError(t, err)

	model, err := provider.LanguageModel(context.Background(), allowlistedResponsesModelID)
	require.NoError(t, err)

	// Compare by type name; the concrete type is unexported in openai.
	gotType := reflect.TypeOf(model).String()
	require.Equal(t, "openai.languageModel", gotType)
}

// TestWithUseResponsesAPIMatchesExplicitAutoMode checks that the legacy
// WithUseResponsesAPI option and the explicit WithAPIMode(APIModeAuto) option
// resolve every model to the same implementation type.
func TestWithUseResponsesAPIMatchesExplicitAutoMode(t *testing.T) {
	t.Parallel()

	legacy, err := New(WithUseResponsesAPI())
	require.NoError(t, err)
	auto, err := New(WithAPIMode(APIModeAuto))
	require.NoError(t, err)

	modelIDs := []string{allowlistedResponsesModelID, nonAllowlistedModelID}
	for _, modelID := range modelIDs {
		modelID := modelID // shadow for parallel subtests on pre-Go-1.22 toolchains
		t.Run(modelID, func(t *testing.T) {
			t.Parallel()

			fromLegacy, err := legacy.LanguageModel(context.Background(), modelID)
			require.NoError(t, err)
			fromAuto, err := auto.LanguageModel(context.Background(), modelID)
			require.NoError(t, err)

			// Both selection paths must pick the same concrete implementation.
			require.Equal(t, reflect.TypeOf(fromLegacy), reflect.TypeOf(fromAuto))
		})
	}
}

func TestWithAPIModeRejectsUnknownValue(t *testing.T) {
t.Parallel()

_, err := New(WithAPIMode(APIMode("definitely-invalid")))
require.EqualError(t, err, `invalid OpenAI API mode "definitely-invalid"`)
}
21 changes: 20 additions & 1 deletion providers/azure/azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,18 @@ const (
defaultAPIVersion = "2025-01-01-preview"
)

// APIMode selects which OpenAI-compatible API Fantasy should target for Azure.
// It is an alias for openai.APIMode so Azure callers do not need to import the
// openai package directly.
type APIMode = openai.APIMode

// Re-exported API mode values; see the openai package for their full semantics.
const (
	// APIModeAuto uses Fantasy's Responses API allowlist heuristic.
	APIModeAuto = openai.APIModeAuto
	// APIModeChatCompletions forces the Chat Completions implementation.
	APIModeChatCompletions = openai.APIModeChatCompletions
	// APIModeResponses forces the Responses implementation.
	APIModeResponses = openai.APIModeResponses
)

// azureURLPattern matches Azure OpenAI endpoint URLs in various formats:
// * https://resource-id.openai.azure.com;
// * https://resource-id.openai.azure.com/;
Expand Down Expand Up @@ -117,7 +129,14 @@ func WithUserAgent(ua string) Option {
}
}

// WithUseResponsesAPI configures the provider to use the responses API for models that support it.
// WithAPIMode deterministically selects the Azure OpenAI API mode.
func WithAPIMode(mode APIMode) Option {
	return func(o *options) {
		// Delegate mode selection to the underlying openai provider option.
		opt := openai.WithAPIMode(mode)
		o.openaiOptions = append(o.openaiOptions, opt)
	}
}

// WithUseResponsesAPI configures the provider to use the responses API heuristic for models that support it.
func WithUseResponsesAPI() Option {
return func(o *options) {
o.openaiOptions = append(o.openaiOptions, openai.WithUseResponsesAPI())
Expand Down
165 changes: 165 additions & 0 deletions providers/openai/api_mode_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
package openai

import (
"context"
"reflect"
"testing"

"charm.land/fantasy"
sdkopenai "github.com/openai/openai-go/v2"
"github.com/stretchr/testify/require"
)

// Model IDs used to exercise both sides of the Responses API allowlist.
const (
	// allowlistedResponsesModelID is expected to be on the Responses API
	// allowlist (the tests below rely on this).
	allowlistedResponsesModelID = "gpt-4o"
	// nonAllowlistedModelID is an ID that should never appear on the allowlist.
	nonAllowlistedModelID = "definitely-not-in-responses-allowlist"
)

// TestLanguageModelDefaultModeUsesChatCompletions verifies that, absent any
// API-mode option, even an allowlisted model gets the Chat Completions
// implementation.
func TestLanguageModelDefaultModeUsesChatCompletions(t *testing.T) {
	t.Parallel()

	provider, err := New()
	require.NoError(t, err)

	model, err := provider.LanguageModel(context.Background(), allowlistedResponsesModelID)
	require.NoError(t, err)

	require.IsType(t, languageModel{}, model)
}

// TestLanguageModelWithUseResponsesAPIRemainsHeuristicOnly checks that the
// legacy WithUseResponsesAPI option still applies the allowlist heuristic:
// allowlisted models get the Responses implementation, others fall back to
// Chat Completions.
func TestLanguageModelWithUseResponsesAPIRemainsHeuristicOnly(t *testing.T) {
	t.Parallel()

	provider, err := New(WithUseResponsesAPI())
	require.NoError(t, err)

	// Allowlisted model -> Responses implementation.
	onList, err := provider.LanguageModel(context.Background(), allowlistedResponsesModelID)
	require.NoError(t, err)
	require.IsType(t, responsesLanguageModel{}, onList)

	// Non-allowlisted model -> Chat Completions implementation.
	offList, err := provider.LanguageModel(context.Background(), nonAllowlistedModelID)
	require.NoError(t, err)
	require.IsType(t, languageModel{}, offList)
}

// TestLanguageModelWithAPIModeAutoMatchesLegacyHeuristic verifies that
// WithAPIMode(APIModeAuto) and the legacy WithUseResponsesAPI option choose
// the same implementation type for every model.
func TestLanguageModelWithAPIModeAutoMatchesLegacyHeuristic(t *testing.T) {
	t.Parallel()

	legacy, err := New(WithUseResponsesAPI())
	require.NoError(t, err)
	auto, err := New(WithAPIMode(APIModeAuto))
	require.NoError(t, err)

	modelIDs := []string{allowlistedResponsesModelID, nonAllowlistedModelID}
	for _, modelID := range modelIDs {
		modelID := modelID // shadow for parallel subtests on pre-Go-1.22 toolchains
		t.Run(modelID, func(t *testing.T) {
			t.Parallel()

			fromLegacy, err := legacy.LanguageModel(context.Background(), modelID)
			require.NoError(t, err)
			fromAuto, err := auto.LanguageModel(context.Background(), modelID)
			require.NoError(t, err)

			// Both selection paths must pick the same concrete implementation.
			require.Equal(t, reflect.TypeOf(fromLegacy), reflect.TypeOf(fromAuto))
		})
	}
}

// TestLanguageModelWithAPIModeResponsesBypassesAllowlist verifies that forcing
// APIModeResponses selects the Responses implementation regardless of the
// allowlist.
func TestLanguageModelWithAPIModeResponsesBypassesAllowlist(t *testing.T) {
	t.Parallel()

	provider, err := New(WithAPIMode(APIModeResponses))
	require.NoError(t, err)

	model, err := provider.LanguageModel(context.Background(), nonAllowlistedModelID)
	require.NoError(t, err)

	require.IsType(t, responsesLanguageModel{}, model)
}

// TestLanguageModelWithAPIModeChatCompletionsBypassesAllowlist verifies that
// forcing APIModeChatCompletions selects the Chat Completions implementation
// even for an allowlisted model.
func TestLanguageModelWithAPIModeChatCompletionsBypassesAllowlist(t *testing.T) {
	t.Parallel()

	provider, err := New(WithAPIMode(APIModeChatCompletions))
	require.NoError(t, err)

	model, err := provider.LanguageModel(context.Background(), allowlistedResponsesModelID)
	require.NoError(t, err)

	require.IsType(t, languageModel{}, model)
}

func TestWithAPIModeRejectsUnknownValue(t *testing.T) {
t.Parallel()

_, err := New(WithAPIMode(APIMode("definitely-invalid")))
require.EqualError(t, err, `invalid OpenAI API mode "definitely-invalid"`)
}

// TestChatProviderOptionsAcceptsChatOptions checks that DefaultPrepareCallFunc
// applies chat-style provider options (here, User) to the request params
// without warnings.
func TestChatProviderOptionsAcceptsChatOptions(t *testing.T) {
	t.Parallel()

	const user = "test-user"
	userVal := user
	params := &sdkopenai.ChatCompletionNewParams{}
	call := fantasy.Call{ProviderOptions: NewProviderOptions(&ProviderOptions{User: &userVal})}

	warnings, err := DefaultPrepareCallFunc(
		newLanguageModel("gpt-4o", Name, sdkopenai.Client{}),
		params,
		call,
	)

	require.NoError(t, err)
	require.Empty(t, warnings)
	// The User option must land on the outgoing params.
	require.True(t, params.User.Valid())
	require.Equal(t, user, params.User.Value)
}

func TestChatProviderOptionsRejectsResponsesOptions(t *testing.T) {
t.Parallel()

_, err := DefaultPrepareCallFunc(
newLanguageModel("gpt-4o", Name, sdkopenai.Client{}),
&sdkopenai.ChatCompletionNewParams{},
fantasy.Call{ProviderOptions: NewResponsesProviderOptions(&ResponsesProviderOptions{})},
)
require.EqualError(t, err, "invalid argument: openai chat_completions API mode expects provider options *openai.ProviderOptions, got *openai.ResponsesProviderOptions")
}

// TestResponsesProviderOptionsAcceptsResponsesOptions checks that the
// Responses model applies Responses-style provider options (here, User) to the
// prepared params without warnings.
func TestResponsesProviderOptionsAcceptsResponsesOptions(t *testing.T) {
	t.Parallel()

	const user = "test-user"
	userVal := user
	model := newResponsesLanguageModel("gpt-4o", Name, sdkopenai.Client{}, fantasy.ObjectModeAuto)

	call := fantasy.Call{
		ProviderOptions: NewResponsesProviderOptions(&ResponsesProviderOptions{User: &userVal}),
	}
	params, warnings, err := model.prepareParams(call)

	require.NoError(t, err)
	require.Empty(t, warnings)
	// The User option must land on the prepared params.
	require.True(t, params.User.Valid())
	require.Equal(t, user, params.User.Value)
}

// TestResponsesProviderOptionsRejectsChatOptions checks that the Responses
// model returns a typed-mismatch error when given chat-style provider options.
func TestResponsesProviderOptionsRejectsChatOptions(t *testing.T) {
	t.Parallel()

	model := newResponsesLanguageModel("gpt-4o", Name, sdkopenai.Client{}, fantasy.ObjectModeAuto)
	call := fantasy.Call{ProviderOptions: NewProviderOptions(&ProviderOptions{})}

	_, _, err := model.prepareParams(call)
	require.EqualError(t, err, "invalid argument: openai responses API mode expects provider options *openai.ResponsesProviderOptions, got *openai.ProviderOptions")
}

// TestChatProviderOptionsAllowsNoProviderOptions checks that the chat prepare
// hook succeeds when the call carries no provider options at all, leaving User
// unset on the params.
func TestChatProviderOptionsAllowsNoProviderOptions(t *testing.T) {
	t.Parallel()

	params := &sdkopenai.ChatCompletionNewParams{}
	model := newLanguageModel("gpt-4o", Name, sdkopenai.Client{})

	warnings, err := DefaultPrepareCallFunc(model, params, fantasy.Call{})

	require.NoError(t, err)
	require.Empty(t, warnings)
	// No options were supplied, so User must remain unset.
	require.False(t, params.User.Valid())
}

// TestResponsesProviderOptionsAllowsNoProviderOptions checks that the
// Responses model prepares params successfully when the call carries no
// provider options, and that Store ends up set (valid) on the result.
func TestResponsesProviderOptionsAllowsNoProviderOptions(t *testing.T) {
	t.Parallel()

	model := newResponsesLanguageModel("gpt-4o", Name, sdkopenai.Client{}, fantasy.ObjectModeAuto)

	params, warnings, err := model.prepareParams(fantasy.Call{})

	require.NoError(t, err)
	require.Empty(t, warnings)
	// Store is populated even without explicit provider options.
	require.True(t, params.Store.Valid())
}
9 changes: 3 additions & 6 deletions providers/openai/language_model_hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,9 @@ func DefaultPrepareCallFunc(model fantasy.LanguageModel, params *openai.ChatComp
return nil, nil
}
var warnings []fantasy.CallWarning
providerOptions := &ProviderOptions{}
if v, ok := call.ProviderOptions[Name]; ok {
providerOptions, ok = v.(*ProviderOptions)
if !ok {
return nil, &fantasy.Error{Title: "invalid argument", Message: "openai provider options should be *openai.ProviderOptions"}
}
providerOptions, err := chatProviderOptions(call)
if err != nil {
return nil, err
}

if providerOptions.LogitBias != nil {
Expand Down
Loading