veylant/internal/provider/openai/adapter_test.go
2026-02-23 13:35:04 +01:00

220 lines
6.9 KiB
Go

package openai_test
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/veylant/ia-gateway/internal/provider"
"github.com/veylant/ia-gateway/internal/provider/openai"
)
// newTestAdapter returns an OpenAI adapter whose BaseURL is serverURL
// (normally an httptest.Server), using a dummy API key, a 5-second timeout
// and a MaxConns of 10.
func newTestAdapter(serverURL string) *openai.Adapter {
	cfg := openai.Config{
		BaseURL:        serverURL,
		APIKey:         "test-key",
		MaxConns:       10,
		TimeoutSeconds: 5,
	}
	return openai.New(cfg)
}
// minimalRequest builds the smallest valid ChatRequest shared by these tests:
// model "gpt-4o" with a single user message "hello".
func minimalRequest() *provider.ChatRequest {
	msgs := []provider.Message{
		{Role: "user", Content: "hello"},
	}
	req := &provider.ChatRequest{Model: "gpt-4o", Messages: msgs}
	return req
}
// ─── Send ────────────────────────────────────────────────────────────────────
// TestAdapter_Send_Nominal checks that Send targets /chat/completions with the
// configured bearer key and decodes the upstream JSON body into a ChatResponse.
func TestAdapter_Send_Nominal(t *testing.T) {
	expected := provider.ChatResponse{
		ID:     "chatcmpl-test",
		Object: "chat.completion",
		Model:  "gpt-4o",
		Choices: []provider.Choice{{
			Index:        0,
			Message:      provider.Message{Role: "assistant", Content: "Hello!"},
			FinishReason: "stop",
		}},
		Usage: provider.Usage{PromptTokens: 5, CompletionTokens: 3, TotalTokens: 8},
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/chat/completions", r.URL.Path)
		assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
		w.Header().Set("Content-Type", "application/json")
		// Surface encoding failures instead of discarding them. Use assert
		// (not require) because this runs on the server's goroutine, where
		// t.FailNow must not be called.
		assert.NoError(t, json.NewEncoder(w).Encode(expected))
	}))
	defer srv.Close()

	adapter := newTestAdapter(srv.URL)
	got, err := adapter.Send(context.Background(), minimalRequest())
	require.NoError(t, err)
	require.NotNil(t, got)
	assert.Equal(t, expected.ID, got.ID)
	// Guard the index below: an empty Choices slice would otherwise panic and
	// abort the whole test binary instead of reporting a clean failure.
	require.NotEmpty(t, got.Choices)
	assert.Equal(t, "Hello!", got.Choices[0].Message.Content)
}
// TestAdapter_Send_Upstream401_ReturnsAuthError checks that an upstream 401
// surfaces from Send as an error whose message mentions "API key".
func TestAdapter_Send_Upstream401_ReturnsAuthError(t *testing.T) {
	unauthorized := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
	})
	srv := httptest.NewServer(unauthorized)
	defer srv.Close()

	_, err := newTestAdapter(srv.URL).Send(context.Background(), minimalRequest())
	require.Error(t, err)
	assert.Contains(t, err.Error(), "API key")
}
// ─── Stream ──────────────────────────────────────────────────────────────────
// TestAdapter_Stream_Nominal feeds Stream an OpenAI-style SSE body (two content
// deltas followed by the [DONE] sentinel) and checks that both deltas and the
// sentinel reach the downstream writer.
func TestAdapter_Stream_Nominal(t *testing.T) {
	chunks := []string{
		`data: {"id":"chatcmpl-1","choices":[{"delta":{"content":"Hel"}}]}`,
		`data: {"id":"chatcmpl-1","choices":[{"delta":{"content":"lo"}}]}`,
		`data: [DONE]`,
	}
	// One event per line, terminated by a final newline.
	ssePayload := strings.Join(chunks, "\n") + "\n"

	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/chat/completions", r.URL.Path)
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		fmt.Fprint(w, ssePayload)
	}))
	defer upstream.Close()

	req := minimalRequest()
	req.Stream = true

	// Mimic the SSE headers the proxy handler sets before delegating.
	out := httptest.NewRecorder()
	out.Header().Set("Content-Type", "text/event-stream")
	out.Header().Set("Cache-Control", "no-cache")

	require.NoError(t, newTestAdapter(upstream.URL).Stream(context.Background(), req, out))

	body := out.Body.String()
	for _, want := range []string{`"Hel"`, `"lo"`, "data: [DONE]"} {
		assert.Contains(t, body, want)
	}
}
// ─── HealthCheck ─────────────────────────────────────────────────────────────
// TestAdapter_HealthCheck_Nominal checks that HealthCheck queries /models and
// returns no error when the upstream answers 200 with an empty model list.
func TestAdapter_HealthCheck_Nominal(t *testing.T) {
	healthy := func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/models", r.URL.Path)
		w.WriteHeader(http.StatusOK)
		_, _ = fmt.Fprint(w, `{"data":[]}`)
	}
	srv := httptest.NewServer(http.HandlerFunc(healthy))
	defer srv.Close()

	require.NoError(t, newTestAdapter(srv.URL).HealthCheck(context.Background()))
}
// TestAdapter_Send_Upstream429_ReturnsRateLimit checks that an upstream 429 is
// reported by Send as an error mentioning "rate limit".
func TestAdapter_Send_Upstream429_ReturnsRateLimit(t *testing.T) {
	tooMany := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTooManyRequests)
	}
	srv := httptest.NewServer(http.HandlerFunc(tooMany))
	defer srv.Close()

	_, err := newTestAdapter(srv.URL).Send(context.Background(), minimalRequest())
	require.Error(t, err)
	assert.Contains(t, err.Error(), "rate limit")
}
// TestAdapter_Send_Upstream400_ReturnsBadRequest checks that an upstream 400 is
// reported by Send as an error containing "rejected".
func TestAdapter_Send_Upstream400_ReturnsBadRequest(t *testing.T) {
	badRequest := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
	}
	srv := httptest.NewServer(http.HandlerFunc(badRequest))
	defer srv.Close()

	_, err := newTestAdapter(srv.URL).Send(context.Background(), minimalRequest())
	require.Error(t, err)
	assert.Contains(t, err.Error(), "rejected")
}
// TestAdapter_Send_Upstream500_ReturnsUpstreamError checks that an upstream 500
// is reported by Send as an error that includes the status code "500".
func TestAdapter_Send_Upstream500_ReturnsUpstreamError(t *testing.T) {
	broken := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}
	srv := httptest.NewServer(http.HandlerFunc(broken))
	defer srv.Close()

	_, err := newTestAdapter(srv.URL).Send(context.Background(), minimalRequest())
	require.Error(t, err)
	assert.Contains(t, err.Error(), "500")
}
// TestAdapter_Stream_Upstream401_ReturnsError checks that Stream returns an
// error when the upstream answers the streaming request with 401.
func TestAdapter_Stream_Upstream401_ReturnsError(t *testing.T) {
	unauthorized := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
	}
	srv := httptest.NewServer(http.HandlerFunc(unauthorized))
	defer srv.Close()

	streamReq := minimalRequest()
	streamReq.Stream = true

	err := newTestAdapter(srv.URL).Stream(context.Background(), streamReq, httptest.NewRecorder())
	require.Error(t, err)
}
// TestAdapter_HealthCheck_FailureReturnsError checks that an upstream 503 makes
// HealthCheck fail with an error that includes the status code "503".
func TestAdapter_HealthCheck_FailureReturnsError(t *testing.T) {
	unavailable := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusServiceUnavailable)
	}
	srv := httptest.NewServer(http.HandlerFunc(unavailable))
	defer srv.Close()

	err := newTestAdapter(srv.URL).HealthCheck(context.Background())
	require.Error(t, err)
	assert.Contains(t, err.Error(), "503")
}
// ─── Validate ────────────────────────────────────────────────────────────────
// TestAdapter_Validate_MissingModel checks that Validate rejects a request with
// no model and names "model" in the error.
func TestAdapter_Validate_MissingModel(t *testing.T) {
	noModel := &provider.ChatRequest{
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	}
	// Placeholder base URL: this test only exercises Validate.
	err := newTestAdapter("http://unused").Validate(noModel)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "model")
}
// TestAdapter_Validate_EmptyMessages checks that Validate rejects a request
// that has a model but no messages, naming "messages" in the error.
func TestAdapter_Validate_EmptyMessages(t *testing.T) {
	noMessages := &provider.ChatRequest{Model: "gpt-4o"}
	// Placeholder base URL: this test only exercises Validate.
	err := newTestAdapter("http://unused").Validate(noMessages)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "messages")
}