220 lines
6.9 KiB
Go
220 lines
6.9 KiB
Go
package openai_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/veylant/ia-gateway/internal/provider"
|
|
"github.com/veylant/ia-gateway/internal/provider/openai"
|
|
)
|
|
|
|
// newTestAdapter builds an Adapter pointed at the given mock server URL.
|
|
func newTestAdapter(serverURL string) *openai.Adapter {
|
|
return openai.New(openai.Config{
|
|
APIKey: "test-key",
|
|
BaseURL: serverURL,
|
|
TimeoutSeconds: 5,
|
|
MaxConns: 10,
|
|
})
|
|
}
|
|
|
|
// minimalRequest returns a valid ChatRequest for tests.
|
|
func minimalRequest() *provider.ChatRequest {
|
|
return &provider.ChatRequest{
|
|
Model: "gpt-4o",
|
|
Messages: []provider.Message{{Role: "user", Content: "hello"}},
|
|
}
|
|
}
|
|
|
|
// ─── Send ────────────────────────────────────────────────────────────────────
|
|
|
|
// TestAdapter_Send_Nominal checks that Send targets /chat/completions with the
// configured bearer token and returns the upstream JSON response decoded.
func TestAdapter_Send_Nominal(t *testing.T) {
	expected := provider.ChatResponse{
		ID:     "chatcmpl-test",
		Object: "chat.completion",
		Model:  "gpt-4o",
		Choices: []provider.Choice{{
			Index:        0,
			Message:      provider.Message{Role: "assistant", Content: "Hello!"},
			FinishReason: "stop",
		}},
		Usage: provider.Usage{PromptTokens: 5, CompletionTokens: 3, TotalTokens: 8},
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/chat/completions", r.URL.Path)
		assert.Equal(t, "Bearer test-key", r.Header.Get("Authorization"))
		w.Header().Set("Content-Type", "application/json")
		// The handler runs on the server's goroutine, so failures are reported
		// with assert (t.Errorf) rather than require (t.FailNow), which must
		// only be called from the test goroutine.
		assert.NoError(t, json.NewEncoder(w).Encode(expected))
	}))
	defer srv.Close()

	adapter := newTestAdapter(srv.URL)
	got, err := adapter.Send(context.Background(), minimalRequest())

	require.NoError(t, err)
	require.NotNil(t, got)
	assert.Equal(t, expected.ID, got.ID)
	// Guard the index below: an empty Choices slice would otherwise panic and
	// abort the test binary instead of failing this test cleanly.
	require.Len(t, got.Choices, 1)
	assert.Equal(t, "Hello!", got.Choices[0].Message.Content)
}
|
|
|
|
// TestAdapter_Send_Upstream401_ReturnsAuthError checks that an upstream 401
// makes Send fail with an error whose message mentions "API key".
func TestAdapter_Send_Upstream401_ReturnsAuthError(t *testing.T) {
	unauthorized := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
	}
	upstream := httptest.NewServer(http.HandlerFunc(unauthorized))
	defer upstream.Close()

	_, sendErr := newTestAdapter(upstream.URL).Send(context.Background(), minimalRequest())

	require.Error(t, sendErr)
	// The error message must contain "API key".
	assert.Contains(t, sendErr.Error(), "API key")
}
|
|
|
|
// ─── Stream ──────────────────────────────────────────────────────────────────
|
|
|
|
// TestAdapter_Stream_Nominal checks that Stream relays upstream SSE chunks to
// the writer in order, followed by the [DONE] sentinel.
func TestAdapter_Stream_Nominal(t *testing.T) {
	// Simulate an OpenAI SSE response with two chunks and the [DONE] sentinel.
	// Each event is terminated by a blank line, as the SSE format specifies,
	// so the adapter is exercised against realistic framing.
	ssePayload := strings.Join([]string{
		`data: {"id":"chatcmpl-1","choices":[{"delta":{"content":"Hel"}}]}`,
		`data: {"id":"chatcmpl-1","choices":[{"delta":{"content":"lo"}}]}`,
		`data: [DONE]`,
	}, "\n\n") + "\n\n"

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/chat/completions", r.URL.Path)
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		// assert (not require): this runs on the server goroutine.
		_, err := fmt.Fprint(w, ssePayload)
		assert.NoError(t, err)
	}))
	defer srv.Close()

	adapter := newTestAdapter(srv.URL)

	rec := httptest.NewRecorder()
	// Set SSE headers as the proxy handler would.
	rec.Header().Set("Content-Type", "text/event-stream")
	rec.Header().Set("Cache-Control", "no-cache")

	req := minimalRequest()
	req.Stream = true
	err := adapter.Stream(context.Background(), req, rec)

	require.NoError(t, err)
	body := rec.Body.String()
	assert.Contains(t, body, `"Hel"`)
	assert.Contains(t, body, `"lo"`)
	assert.Contains(t, body, "data: [DONE]")
	// Chunks must be relayed in upstream order, followed by the sentinel.
	assert.Less(t, strings.Index(body, `"Hel"`), strings.Index(body, `"lo"`))
	assert.Less(t, strings.Index(body, `"lo"`), strings.Index(body, "data: [DONE]"))
}
|
|
|
|
// ─── HealthCheck ─────────────────────────────────────────────────────────────
|
|
|
|
// TestAdapter_HealthCheck_Nominal checks that HealthCheck queries /models and
// succeeds when the upstream answers 200 with an empty model list.
func TestAdapter_HealthCheck_Nominal(t *testing.T) {
	modelsOK := func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "/models", r.URL.Path)
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"data":[]}`))
	}
	upstream := httptest.NewServer(http.HandlerFunc(modelsOK))
	defer upstream.Close()

	require.NoError(t, newTestAdapter(upstream.URL).HealthCheck(context.Background()))
}
|
|
|
|
// TestAdapter_Send_Upstream429_ReturnsRateLimit checks that an upstream 429
// makes Send fail with an error whose message mentions "rate limit".
func TestAdapter_Send_Upstream429_ReturnsRateLimit(t *testing.T) {
	throttled := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTooManyRequests)
	}
	upstream := httptest.NewServer(http.HandlerFunc(throttled))
	defer upstream.Close()

	_, sendErr := newTestAdapter(upstream.URL).Send(context.Background(), minimalRequest())

	require.Error(t, sendErr)
	assert.Contains(t, sendErr.Error(), "rate limit")
}
|
|
|
|
// TestAdapter_Send_Upstream400_ReturnsBadRequest checks that an upstream 400
// makes Send fail with an error whose message mentions "rejected".
func TestAdapter_Send_Upstream400_ReturnsBadRequest(t *testing.T) {
	badRequest := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusBadRequest)
	}
	upstream := httptest.NewServer(http.HandlerFunc(badRequest))
	defer upstream.Close()

	_, sendErr := newTestAdapter(upstream.URL).Send(context.Background(), minimalRequest())

	require.Error(t, sendErr)
	assert.Contains(t, sendErr.Error(), "rejected")
}
|
|
|
|
// TestAdapter_Send_Upstream500_ReturnsUpstreamError checks that an upstream
// 500 makes Send fail with an error whose message includes the status "500".
func TestAdapter_Send_Upstream500_ReturnsUpstreamError(t *testing.T) {
	serverError := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusInternalServerError)
	}
	upstream := httptest.NewServer(http.HandlerFunc(serverError))
	defer upstream.Close()

	_, sendErr := newTestAdapter(upstream.URL).Send(context.Background(), minimalRequest())

	require.Error(t, sendErr)
	assert.Contains(t, sendErr.Error(), "500")
}
|
|
|
|
// TestAdapter_Stream_Upstream401_ReturnsError checks that Stream returns an
// error when the upstream rejects the streaming request with 401.
func TestAdapter_Stream_Upstream401_ReturnsError(t *testing.T) {
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
	}))
	defer upstream.Close()

	a := newTestAdapter(upstream.URL)

	streamReq := minimalRequest()
	streamReq.Stream = true
	sink := httptest.NewRecorder()

	require.Error(t, a.Stream(context.Background(), streamReq, sink))
}
|
|
|
|
// TestAdapter_HealthCheck_FailureReturnsError checks that an upstream 503
// makes HealthCheck fail with an error whose message includes "503".
func TestAdapter_HealthCheck_FailureReturnsError(t *testing.T) {
	unavailable := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusServiceUnavailable)
	}
	upstream := httptest.NewServer(http.HandlerFunc(unavailable))
	defer upstream.Close()

	checkErr := newTestAdapter(upstream.URL).HealthCheck(context.Background())

	require.Error(t, checkErr)
	assert.Contains(t, checkErr.Error(), "503")
}
|
|
|
|
// ─── Validate ────────────────────────────────────────────────────────────────
|
|
|
|
// TestAdapter_Validate_MissingModel checks that Validate rejects a request with
// no model and that the error mentions "model". No server is needed: the base
// URL is never contacted.
func TestAdapter_Validate_MissingModel(t *testing.T) {
	a := newTestAdapter("http://unused")
	noModel := &provider.ChatRequest{
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	}

	validateErr := a.Validate(noModel)

	require.Error(t, validateErr)
	assert.Contains(t, validateErr.Error(), "model")
}
|
|
|
|
// TestAdapter_Validate_EmptyMessages checks that Validate rejects a request
// that has a model but no messages, with an error mentioning "messages".
func TestAdapter_Validate_EmptyMessages(t *testing.T) {
	a := newTestAdapter("http://unused")
	noMessages := &provider.ChatRequest{Model: "gpt-4o"}

	validateErr := a.Validate(noMessages)

	require.Error(t, validateErr)
	assert.Contains(t, validateErr.Error(), "messages")
}
|
|
|