182 lines
6.5 KiB
Go
182 lines
6.5 KiB
Go
package azure_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/veylant/ia-gateway/internal/provider"
|
|
"github.com/veylant/ia-gateway/internal/provider/azure"
|
|
)
|
|
|
|
// newTestAdapter builds an Azure adapter whose requests are sent to serverURL
// (an httptest.Server URL in these tests) via azure.NewWithBaseURL.
//
// ResourceName is set to a placeholder; presumably it is not used to build the
// endpoint when a base URL is supplied — confirm against the adapter. APIKey and
// APIVersion are the values the mock handlers in this file assert on (the
// "api-key" header and the "api-version" query parameter).
func newTestAdapter(t *testing.T, serverURL string) *azure.Adapter {
	t.Helper()

	// httptest.NewServer hands out plain-HTTP URLs. An https:// URL (e.g. from
	// httptest.NewTLSServer) would need the test server's certificate, which this
	// helper does not configure, so fail fast with a clear message rather than
	// letting the request presumably fail certificate verification later.
	if !strings.HasPrefix(serverURL, "http://") {
		t.Fatalf("newTestAdapter: expected a plain http:// test server URL, got %q", serverURL)
	}

	return azure.NewWithBaseURL(serverURL, azure.Config{
		APIKey:         "test-key",
		ResourceName:   "test-resource",
		DeploymentID:   "gpt-4o",
		APIVersion:     "2024-02-01",
		TimeoutSeconds: 5,
		MaxConns:       10,
	})
}
|
|
|
|
func minimalRequest() *provider.ChatRequest {
|
|
return &provider.ChatRequest{
|
|
Model: "gpt-4o",
|
|
Messages: []provider.Message{{Role: "user", Content: "hello"}},
|
|
}
|
|
}
|
|
|
|
// ─── Send ────────────────────────────────────────────────────────────────────
|
|
|
|
// TestAzureAdapter_Send_Nominal checks that Send authenticates with the Azure
// "api-key" header (and no Authorization header), passes the configured
// api-version as a query parameter, and decodes the upstream JSON response.
func TestAzureAdapter_Send_Nominal(t *testing.T) {
	expected := provider.ChatResponse{
		ID:     "chatcmpl-azure-test",
		Object: "chat.completion",
		Model:  "gpt-4o",
		Choices: []provider.Choice{{
			Message:      provider.Message{Role: "assistant", Content: "Hello from Azure!"},
			FinishReason: "stop",
		}},
	}

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on the server's goroutine, not the test goroutine,
		// so only assert (never require, which calls t.FailNow) is used here.

		// Verify Azure-specific header (api-key, NOT Authorization Bearer).
		assert.Equal(t, "test-key", r.Header.Get("api-key"))
		assert.Empty(t, r.Header.Get("Authorization"))
		// Verify URL contains api-version.
		assert.Contains(t, r.URL.RawQuery, "api-version=2024-02-01")

		w.Header().Set("Content-Type", "application/json")
		// A failed encode would otherwise surface only as a confusing decode
		// error on the client side.
		assert.NoError(t, json.NewEncoder(w).Encode(expected))
	}))
	defer srv.Close()

	adapter := newTestAdapter(t, srv.URL)
	got, err := adapter.Send(context.Background(), minimalRequest())

	require.NoError(t, err)
	require.NotNil(t, got)
	assert.Equal(t, "chatcmpl-azure-test", got.ID)
	// Guard the index below: an empty Choices slice would otherwise panic and
	// abort the whole test binary instead of failing this test cleanly.
	require.NotEmpty(t, got.Choices)
	assert.Equal(t, "Hello from Azure!", got.Choices[0].Message.Content)
}
|
|
|
|
// TestAzureAdapter_Send_Upstream401_ReturnsAuthError verifies that an HTTP 401
// from the upstream surfaces from Send as an error mentioning the API key.
func TestAzureAdapter_Send_Upstream401_ReturnsAuthError(t *testing.T) {
	unauthorized := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusUnauthorized)
	})
	srv := httptest.NewServer(unauthorized)
	defer srv.Close()

	adapter := newTestAdapter(t, srv.URL)
	_, err := adapter.Send(context.Background(), minimalRequest())

	require.Error(t, err)
	assert.Contains(t, err.Error(), "API key")
}
|
|
|
|
// TestAzureAdapter_Send_Upstream429_ReturnsRateLimit verifies that an HTTP 429
// from the upstream surfaces from Send as a rate-limit error.
func TestAzureAdapter_Send_Upstream429_ReturnsRateLimit(t *testing.T) {
	throttled := func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusTooManyRequests)
	}
	srv := httptest.NewServer(http.HandlerFunc(throttled))
	t.Cleanup(srv.Close)

	_, sendErr := newTestAdapter(t, srv.URL).Send(context.Background(), minimalRequest())

	require.Error(t, sendErr)
	assert.Contains(t, sendErr.Error(), "rate limit")
}
|
|
|
|
// TestAzureAdapter_Send_Upstream500_ReturnsUpstreamError verifies that an HTTP
// 500 from the upstream surfaces from Send as an error carrying the status code.
func TestAzureAdapter_Send_Upstream500_ReturnsUpstreamError(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(
		func(w http.ResponseWriter, _ *http.Request) {
			w.WriteHeader(http.StatusInternalServerError)
		},
	))
	t.Cleanup(srv.Close)

	adapter := newTestAdapter(t, srv.URL)
	ctx := context.Background()
	_, upstreamErr := adapter.Send(ctx, minimalRequest())

	require.Error(t, upstreamErr)
	assert.Contains(t, upstreamErr.Error(), "500")
}
|
|
|
|
// ─── Stream ──────────────────────────────────────────────────────────────────
|
|
|
|
// TestAzureAdapter_Stream_Nominal checks that Stream relays the upstream
// server-sent events to the client writer: every content delta, in upstream
// order, followed by the [DONE] terminator.
func TestAzureAdapter_Stream_Nominal(t *testing.T) {
	// NOTE(review): events here are separated by a single "\n", whereas the SSE
	// format ends each event with a blank line ("\n\n"). Consider mirroring real
	// upstream framing once the adapter's parser is confirmed to tolerate it.
	ssePayload := "data: {\"id\":\"chatcmpl-1\",\"choices\":[{\"delta\":{\"content\":\"Hel\"}}]}\n" +
		"data: {\"id\":\"chatcmpl-1\",\"choices\":[{\"delta\":{\"content\":\"lo\"}}]}\n" +
		"data: [DONE]\n"

	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(ssePayload))
	}))
	defer srv.Close()

	rec := httptest.NewRecorder()
	rec.Header().Set("Content-Type", "text/event-stream")

	req := minimalRequest()
	req.Stream = true
	err := newTestAdapter(t, srv.URL).Stream(context.Background(), req, rec)

	require.NoError(t, err)
	body := rec.Body.String()
	assert.Contains(t, body, "Hel")
	// Checking only the first delta would let an adapter that drops or reorders
	// later chunks pass, so also require the second delta and the ordering.
	assert.Contains(t, body, `"lo"`)
	assert.Contains(t, body, "[DONE]")
	assert.Less(t, strings.Index(body, "Hel"), strings.Index(body, `"lo"`),
		"deltas must be relayed in upstream order")
	assert.Less(t, strings.Index(body, `"lo"`), strings.Index(body, "[DONE]"),
		"[DONE] must follow the last delta")
}
|
|
|
|
// ─── Validate ────────────────────────────────────────────────────────────────
|
|
|
|
// TestAzureAdapter_Validate_MissingModel verifies that Validate rejects a
// request that has messages but no model, naming the model in the error.
func TestAzureAdapter_Validate_MissingModel(t *testing.T) {
	adapter := azure.NewWithBaseURL("http://unused", azure.Config{APIKey: "k"})
	noModel := &provider.ChatRequest{
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	}

	validationErr := adapter.Validate(noModel)

	require.Error(t, validationErr)
	assert.Contains(t, validationErr.Error(), "model")
}
|
|
|
|
// TestAzureAdapter_Validate_EmptyMessages verifies that Validate rejects a
// request that names a model but carries no messages.
func TestAzureAdapter_Validate_EmptyMessages(t *testing.T) {
	cfg := azure.Config{APIKey: "k"}
	err := azure.NewWithBaseURL("http://unused", cfg).
		Validate(&provider.ChatRequest{Model: "gpt-4o"})

	require.Error(t, err)
	assert.Contains(t, err.Error(), "messages")
}
|
|
|
|
// ─── HealthCheck ─────────────────────────────────────────────────────────────
|
|
|
|
// TestAzureAdapter_HealthCheck_Nominal verifies that HealthCheck sends the
// api-key header and api-version query parameter and succeeds on HTTP 200.
func TestAzureAdapter_HealthCheck_Nominal(t *testing.T) {
	healthy := func(w http.ResponseWriter, r *http.Request) {
		assert.Equal(t, "test-key", r.Header.Get("api-key"))
		assert.Contains(t, r.URL.RawQuery, "api-version=2024-02-01")
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte(`{"data":[]}`))
	}
	srv := httptest.NewServer(http.HandlerFunc(healthy))
	defer srv.Close()

	adapter := newTestAdapter(t, srv.URL)
	require.NoError(t, adapter.HealthCheck(context.Background()))
}
|
|
|
|
// TestAzureAdapter_HealthCheck_Failure verifies that a 503 from the upstream
// makes HealthCheck fail with an error carrying the status code.
func TestAzureAdapter_HealthCheck_Failure(t *testing.T) {
	unavailable := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusServiceUnavailable)
	})
	srv := httptest.NewServer(unavailable)
	t.Cleanup(srv.Close)

	adapter := newTestAdapter(t, srv.URL)
	healthErr := adapter.HealthCheck(context.Background())

	require.Error(t, healthErr)
	assert.Contains(t, healthErr.Error(), "503")
}
|