package openai_test

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/veylant/ia-gateway/internal/provider"
	"github.com/veylant/ia-gateway/internal/provider/openai"
)

// newTestAdapter builds an Adapter pointed at the given mock server URL.
func newTestAdapter(serverURL string) *openai.Adapter {
	return openai.New(openai.Config{
		APIKey:         "test-key",
		BaseURL:        serverURL,
		TimeoutSeconds: 5,
		MaxConns:       10,
	})
}

// minimalRequest returns a valid ChatRequest for tests.
func minimalRequest() *provider.ChatRequest {
	return &provider.ChatRequest{
		Model:    "gpt-4o",
		Messages: []provider.Message{{Role: "user", Content: "hello"}},
	}
}

// newStatusServer starts a mock upstream that answers every request with the
// given HTTP status code and an empty body. The server is closed automatically
// when the calling test finishes.
func newStatusServer(t *testing.T, status int) *httptest.Server {
	t.Helper()
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(status)
	}))
	t.Cleanup(srv.Close)
	return srv
}

// ─── Send ────────────────────────────────────────────────────────────────────

func TestAdapter_Send_Nominal(t *testing.T) {
	expected := provider.ChatResponse{
		ID:     "chatcmpl-test",
		Object: "chat.completion",
		Model:  "gpt-4o",
		Choices: []provider.Choice{{
			Index:        0,
			Message:      provider.Message{Role: "assistant", Content: "Hello!"},
			FinishReason: "stop",
		}},
		Usage: provider.Usage{PromptTokens: 5, CompletionTokens: 3, TotalTokens: 8},
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Use assert (not require) here: the handler runs on the server's
		// goroutine, where t.FailNow must not be called.
		assert.Equal(t, "/chat/completions", r.URL.Path)
		assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
		w.Header().Set("Content-Type", "application/json")
		assert.NoError(t, json.NewEncoder(w).Encode(expected))
	}))
	defer srv.Close()

	adapter := newTestAdapter(srv.URL)
	got, err := adapter.Send(context.Background(), minimalRequest())
	require.NoError(t, err)
	require.NotNil(t, got)
	assert.Equal(t, expected.ID, got.ID)
	// Guard the index below so an empty Choices slice fails the test cleanly
	// instead of panicking.
	require.NotEmpty(t, got.Choices)
	assert.Equal(t, "Hello!", got.Choices[0].Message.Content)
}

func TestAdapter_Send_Upstream401_ReturnsAuthError(t *testing.T) {
	srv := newStatusServer(t, http.StatusUnauthorized)

	adapter := newTestAdapter(srv.URL)
	_, err := adapter.Send(context.Background(), minimalRequest())
	require.Error(t, err)
	// The error message must contain "API key".
	assert.Contains(t, err.Error(), "API key")
}

func TestAdapter_Send_Upstream429_ReturnsRateLimit(t *testing.T) {
	srv := newStatusServer(t, http.StatusTooManyRequests)

	adapter := newTestAdapter(srv.URL)
	_, err := adapter.Send(context.Background(), minimalRequest())
	require.Error(t, err)
	assert.Contains(t, err.Error(), "rate limit")
}

func TestAdapter_Send_Upstream400_ReturnsBadRequest(t *testing.T) {
	srv := newStatusServer(t, http.StatusBadRequest)

	adapter := newTestAdapter(srv.URL)
	_, err := adapter.Send(context.Background(), minimalRequest())
	require.Error(t, err)
	assert.Contains(t, err.Error(), "rejected")
}

func TestAdapter_Send_Upstream500_ReturnsUpstreamError(t *testing.T) {
	srv := newStatusServer(t, http.StatusInternalServerError)

	adapter := newTestAdapter(srv.URL)
	_, err := adapter.Send(context.Background(), minimalRequest())
	require.Error(t, err)
	assert.Contains(t, err.Error(), "500")
}

// ─── Stream ──────────────────────────────────────────────────────────────────

func TestAdapter_Stream_Nominal(t *testing.T) {
	// Simulate an OpenAI SSE response with two chunks and the [DONE] sentinel.
	ssePayload := strings.Join([]string{
		`data: {"id":"chatcmpl-1","choices":[{"delta":{"content":"Hel"}}]}`,
		`data: {"id":"chatcmpl-1","choices":[{"delta":{"content":"lo"}}]}`,
		`data: [DONE]`,
		"",
	}, "\n")

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/chat/completions", r.URL.Path)
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		_, err := fmt.Fprint(w, ssePayload)
		assert.NoError(t, err)
	}))
	defer srv.Close()

	adapter := newTestAdapter(srv.URL)

	rec := httptest.NewRecorder()
	// Set SSE headers as the proxy handler would.
	rec.Header().Set("Content-Type", "text/event-stream")
	rec.Header().Set("Cache-Control", "no-cache")

	req := minimalRequest()
	req.Stream = true

	err := adapter.Stream(context.Background(), req, rec)
	require.NoError(t, err)

	body := rec.Body.String()
	assert.Contains(t, body, `"Hel"`)
	assert.Contains(t, body, `"lo"`)
	assert.Contains(t, body, "data: [DONE]")
}

func TestAdapter_Stream_Upstream401_ReturnsError(t *testing.T) {
	srv := newStatusServer(t, http.StatusUnauthorized)

	adapter := newTestAdapter(srv.URL)
	req := minimalRequest()
	req.Stream = true
	rec := httptest.NewRecorder()

	err := adapter.Stream(context.Background(), req, rec)
	require.Error(t, err)
}

// ─── HealthCheck ─────────────────────────────────────────────────────────────

func TestAdapter_HealthCheck_Nominal(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/models", r.URL.Path)
		w.WriteHeader(http.StatusOK)
		_, err := w.Write([]byte(`{"data":[]}`))
		assert.NoError(t, err)
	}))
	defer srv.Close()

	adapter := newTestAdapter(srv.URL)
	err := adapter.HealthCheck(context.Background())
	require.NoError(t, err)
}

func TestAdapter_HealthCheck_FailureReturnsError(t *testing.T) {
	srv := newStatusServer(t, http.StatusServiceUnavailable)

	adapter := newTestAdapter(srv.URL)
	err := adapter.HealthCheck(context.Background())
	require.Error(t, err)
	assert.Contains(t, err.Error(), "503")
}

// ─── Validate ────────────────────────────────────────────────────────────────

func TestAdapter_Validate_MissingModel(t *testing.T) {
	adapter := newTestAdapter("http://unused")
	err := adapter.Validate(&provider.ChatRequest{
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "model")
}

func TestAdapter_Validate_EmptyMessages(t *testing.T) {
	adapter := newTestAdapter("http://unused")
	err := adapter.Validate(&provider.ChatRequest{Model: "gpt-4o"})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "messages")
}