veylant/internal/proxy/handler_test.go
2026-02-23 13:35:04 +01:00

341 lines
11 KiB
Go

package proxy_test
import (
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"google.golang.org/grpc"
"github.com/veylant/ia-gateway/internal/apierror"
"github.com/veylant/ia-gateway/internal/middleware"
piiv1 "github.com/veylant/ia-gateway/gen/pii/v1"
"github.com/veylant/ia-gateway/internal/pii"
"github.com/veylant/ia-gateway/internal/provider"
"github.com/veylant/ia-gateway/internal/proxy"
)
// ─── Stub adapter ─────────────────────────────────────────────────────────────

// stubAdapter is a configurable test double that the tests in this file hand
// to the proxy handler as a provider.Adapter. Each field pre-programs the
// outcome of one of its methods.
type stubAdapter struct {
// sendResp is returned unchanged by Send.
sendResp *provider.ChatResponse
// sendErr is returned unchanged by Send.
sendErr error
// streamErr, when non-nil, is returned by Stream before anything is written.
streamErr error
sseChunks []string // lines to emit during Stream
}
// Validate rejects a request whose Model is empty or whose Messages slice is
// empty with a bad-request error from apierror, checking the model first.
// Any other request is accepted and nil is returned.
func (s *stubAdapter) Validate(req *provider.ChatRequest) error {
	switch {
	case req.Model == "":
		return apierror.NewBadRequestError("field 'model' is required")
	case len(req.Messages) == 0:
		return apierror.NewBadRequestError("field 'messages' must not be empty")
	default:
		return nil
	}
}
// Send ignores both arguments and hands back the pre-programmed
// sendResp / sendErr pair exactly as stored on the stub.
func (s *stubAdapter) Send(_ context.Context, _ *provider.ChatRequest) (*provider.ChatResponse, error) {
	resp, err := s.sendResp, s.sendErr
	return resp, err
}
// Stream writes every entry of sseChunks to w, each followed by a blank line,
// flushing after each one.
//
// It returns streamErr without writing anything when that field is set, and a
// "not flushable" error when w does not implement http.Flusher. The context
// and request arguments are ignored.
func (s *stubAdapter) Stream(_ context.Context, _ *provider.ChatRequest, w http.ResponseWriter) error {
	if err := s.streamErr; err != nil {
		return err
	}

	f, flushable := w.(http.Flusher)
	if !flushable {
		return fmt.Errorf("not flushable")
	}

	for i := range s.sseChunks {
		fmt.Fprintf(w, "%s\n\n", s.sseChunks[i]) //nolint:errcheck
		f.Flush()
	}
	return nil
}
// HealthCheck always returns nil; the context argument is ignored.
func (s *stubAdapter) HealthCheck(_ context.Context) error {
	return nil
}
// ─── Helpers ──────────────────────────────────────────────────────────────────

// newHandler builds a proxy.Handler around adapter using a no-op zap logger
// and nil as the third proxy.New argument (the slot where the PII tests in
// this file pass a *pii.Client).
func newHandler(adapter provider.Adapter) *proxy.Handler {
	logger := zap.NewNop()
	return proxy.New(adapter, logger, nil)
}
// postRequest returns a POST request for /v1/chat/completions whose body is
// the given string, built with httptest.NewRequest.
func postRequest(body string) *http.Request {
	payload := strings.NewReader(body)
	return httptest.NewRequest(http.MethodPost, "/v1/chat/completions", payload)
}
// Canned request bodies shared by the tests in this file.
const (
	// validBody is a minimal chat request: a model and one user message.
	validBody = `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`
	// streamBody is the same request with "stream" set to true.
	streamBody = `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}],"stream":true}`
)
// ─── Tests ────────────────────────────────────────────────────────────────────

// TestHandler_NonStreaming_Nominal checks the non-streaming happy path: the
// stub adapter's response must come back as a 200 application/json body with
// the same ID and assistant message content.
func TestHandler_NonStreaming_Nominal(t *testing.T) {
	expected := &provider.ChatResponse{
		ID:    "chatcmpl-123",
		Model: "gpt-4o",
		Choices: []provider.Choice{{
			Message: provider.Message{Role: "assistant", Content: "Hello!"},
		}},
	}
	h := newHandler(&stubAdapter{sendResp: expected})
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, postRequest(validBody))

	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))

	var got provider.ChatResponse
	require.NoError(t, json.NewDecoder(rec.Body).Decode(&got))
	assert.Equal(t, expected.ID, got.ID)

	// The status check above is non-fatal, so a body without choices can reach
	// this point. Stop cleanly instead of panicking on the index below.
	require.Len(t, got.Choices, 1)
	assert.Equal(t, "Hello!", got.Choices[0].Message.Content)
}
func TestHandler_Streaming_Nominal(t *testing.T) {
chunks := []string{
`data: {"choices":[{"delta":{"content":"He"}}]}`,
`data: {"choices":[{"delta":{"content":"llo"}}]}`,
`data: [DONE]`,
}
h := newHandler(&stubAdapter{sseChunks: chunks})
rec := httptest.NewRecorder()
h.ServeHTTP(rec, postRequest(streamBody))
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "text/event-stream", rec.Header().Get("Content-Type"))
body := rec.Body.String()
assert.Contains(t, body, "He")
assert.Contains(t, body, "llo")
assert.Contains(t, body, "[DONE]")
}
func TestHandler_AdapterReturns429_ProxyReturns429(t *testing.T) {
h := newHandler(&stubAdapter{sendErr: apierror.NewRateLimitError("rate limit exceeded")})
rec := httptest.NewRecorder()
h.ServeHTTP(rec, postRequest(validBody))
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
assertErrorType(t, rec, "rate_limit_error")
}
func TestHandler_AdapterTimeout_Returns504(t *testing.T) {
h := newHandler(&stubAdapter{sendErr: apierror.NewTimeoutError("timed out")})
rec := httptest.NewRecorder()
h.ServeHTTP(rec, postRequest(validBody))
assert.Equal(t, http.StatusGatewayTimeout, rec.Code)
assertErrorType(t, rec, "api_error")
}
func TestHandler_InvalidJSON_Returns400(t *testing.T) {
h := newHandler(&stubAdapter{})
rec := httptest.NewRecorder()
h.ServeHTTP(rec, postRequest(`not json`))
assert.Equal(t, http.StatusBadRequest, rec.Code)
assertErrorType(t, rec, "invalid_request_error")
}
// TestHandler_MissingModel_Returns400 checks that a request without a "model"
// field is answered with HTTP 400 and error type "invalid_request_error".
func TestHandler_MissingModel_Returns400(t *testing.T) {
	const noModel = `{"messages":[{"role":"user","content":"hi"}]}`

	rec := httptest.NewRecorder()
	newHandler(&stubAdapter{}).ServeHTTP(rec, postRequest(noModel))

	assert.Equal(t, http.StatusBadRequest, rec.Code)
	assertErrorType(t, rec, "invalid_request_error")
}
func TestHandler_RequestIDPropagated(t *testing.T) {
const fixedID = "01956a00-1111-7000-8000-000000000042"
h := newHandler(&stubAdapter{sendResp: &provider.ChatResponse{ID: "x"}})
req := postRequest(validBody)
ctx := middleware.RequestIDFromContext(req.Context()) // initially empty
_ = ctx
// Inject the request ID via context (as the RequestID middleware would do).
req = req.WithContext(context.Background())
// Simulate what middleware.RequestID would do.
req.Header.Set("X-Request-Id", fixedID)
// We use a simple wrapper to inject into context for the test.
injectCtx := injectRequestID(req, fixedID)
rec := httptest.NewRecorder()
rec.Header().Set("X-Request-Id", fixedID) // simulate middleware header
h.ServeHTTP(rec, injectCtx)
assert.Equal(t, http.StatusOK, rec.Code)
}
// assertErrorType decodes the recorded response body as an OpenAI-style error
// envelope ({"error":{"type":...}}) and asserts that error.type equals
// expectedType. The test is stopped immediately if the body cannot be decoded
// as JSON. Decoding consumes rec.Body.
func assertErrorType(t *testing.T, rec *httptest.ResponseRecorder, expectedType string) {
	t.Helper()

	type errorDetail struct {
		Type string `json:"type"`
	}
	var envelope struct {
		Error errorDetail `json:"error"`
	}

	err := json.NewDecoder(rec.Body).Decode(&envelope)
	require.NoError(t, err)
	assert.Equal(t, expectedType, envelope.Error.Type)
}
// injectRequestID is meant to mimic the middleware.RequestID middleware for
// test purposes, but it does not currently store id anywhere: it returns a
// shallow copy of r that carries r's own, unchanged context.
//
// NOTE(review): the original comments say the context setter (withRequestID)
// is unexported, which is why the ID cannot be injected from this package.
func injectRequestID(r *http.Request, id string) *http.Request {
	_ = id // intentionally unused, see the note above
	return r.WithContext(r.Context())
}
// ─── Mock PII gRPC server ─────────────────────────────────────────────────

// mockPiiSrv is a test double for the PII gRPC service. It embeds
// piiv1.UnimplementedPiiServiceServer and overrides Detect.
type mockPiiSrv struct {
piiv1.UnimplementedPiiServiceServer
// resp is the canned Detect reply; when nil, Detect echoes the request text.
resp *piiv1.PiiResponse
}
// Detect returns the canned response when one is configured. Otherwise it
// echoes the request text back as AnonymizedText with no entities. It never
// returns an error, and the context argument is ignored.
func (m *mockPiiSrv) Detect(_ context.Context, req *piiv1.PiiRequest) (*piiv1.PiiResponse, error) {
	if canned := m.resp; canned != nil {
		return canned, nil
	}
	echo := &piiv1.PiiResponse{AnonymizedText: req.Text}
	return echo, nil
}
// startPiiServer serves srv over gRPC on an ephemeral loopback port and
// returns a pii.Client pointed at it. The client is configured with a 500ms
// timeout and FailOpen disabled.
//
// Server shutdown and client Close are both registered with t.Cleanup. The
// test is stopped immediately if listening or client construction fails.
func startPiiServer(t *testing.T, srv piiv1.PiiServiceServer) *pii.Client {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	grpcSrv := grpc.NewServer()
	piiv1.RegisterPiiServiceServer(grpcSrv, srv)
	t.Cleanup(grpcSrv.Stop)
	// Serve blocks, so it runs in the background until the cleanup above
	// stops the server.
	go func() { _ = grpcSrv.Serve(listener) }()

	cfg := pii.Config{
		Address:  listener.Addr().String(),
		Timeout:  500 * time.Millisecond,
		FailOpen: false,
	}
	client, err := pii.New(cfg, zap.NewNop())
	require.NoError(t, err)
	t.Cleanup(func() { _ = client.Close() })

	return client
}
// ─── PII integration tests ─────────────────────────────────────────────────

// TestHandler_PII_Nil_ClientSkipsPII checks that a handler built with a nil
// PII client returns the adapter's response content unchanged.
func TestHandler_PII_Nil_ClientSkipsPII(t *testing.T) {
	resp := &provider.ChatResponse{
		ID: "chatcmpl-pii",
		Choices: []provider.Choice{{
			Message: provider.Message{Role: "assistant", Content: "response text"},
		}},
	}
	h := proxy.New(&stubAdapter{sendResp: resp}, zap.NewNop(), nil)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, postRequest(validBody))

	assert.Equal(t, http.StatusOK, rec.Code)

	var got provider.ChatResponse
	require.NoError(t, json.NewDecoder(rec.Body).Decode(&got))

	// The status check above is non-fatal, so a body without choices can reach
	// this point. Stop cleanly instead of panicking on the index below.
	require.Len(t, got.Choices, 1)
	assert.Equal(t, "response text", got.Choices[0].Message.Content)
}
// TestHandler_Depseudo_AppliedToResponse checks the response path when no PII
// client is configured: a PII-style token in the adapter's response must reach
// the caller unchanged, because there is no entity map to restore it from.
// Token restoration itself is tested directly in depseudo_test.go.
func TestHandler_Depseudo_AppliedToResponse(t *testing.T) {
	const token = "[PII:EMAIL:abc12345]"

	resp := &provider.ChatResponse{
		ID: "x",
		Choices: []provider.Choice{{
			Message: provider.Message{Role: "assistant", Content: token},
		}},
	}
	h := proxy.New(&stubAdapter{sendResp: resp}, zap.NewNop(), nil)
	rec := httptest.NewRecorder()
	h.ServeHTTP(rec, postRequest(validBody))

	// Fail fast on an error response, as the other PII tests do, rather than
	// decoding an error body and panicking on the index below.
	require.Equal(t, http.StatusOK, rec.Code)

	var got provider.ChatResponse
	require.NoError(t, json.NewDecoder(rec.Body).Decode(&got))
	require.Len(t, got.Choices, 1)

	// Without a PII client the token should be left unchanged.
	assert.Equal(t, token, got.Choices[0].Message.Content)
}
func TestHandler_PII_AnonymizesRequestAndDepseudonymizesResponse(t *testing.T) {
// The PII service returns an anonymized text and an entity mapping.
// The handler should: (1) replace last user message with anonymized text,
// (2) de-pseudonymize the LLM response.
piiSrv := &mockPiiSrv{
resp: &piiv1.PiiResponse{
AnonymizedText: "[PII:EMAIL:abc12345]",
Entities: []*piiv1.PiiEntity{
{
EntityType: "EMAIL",
OriginalValue: "alice@example.com",
Pseudonym: "[PII:EMAIL:abc12345]",
},
},
},
}
piiClient := startPiiServer(t, piiSrv)
// The LLM "echoes" the anonymized token in the response.
llmResp := &provider.ChatResponse{
ID: "x",
Choices: []provider.Choice{{
Message: provider.Message{Role: "assistant", Content: "Your email is [PII:EMAIL:abc12345]"},
}},
}
h := proxy.New(&stubAdapter{sendResp: llmResp}, zap.NewNop(), piiClient)
body := `{"model":"gpt-4o","messages":[{"role":"user","content":"alice@example.com"}]}`
rec := httptest.NewRecorder()
h.ServeHTTP(rec, postRequest(body))
require.Equal(t, http.StatusOK, rec.Code)
var got provider.ChatResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&got))
// The PII token in the response must be restored to the original value.
assert.Contains(t, got.Choices[0].Message.Content, "alice@example.com")
}
// TestHandler_PII_NoEntitiesDetected_ResponseUnchanged uses a mock PII service
// with no canned reply, so Detect echoes the text back with no entities. The
// adapter's response content must then reach the caller untouched.
func TestHandler_PII_NoEntitiesDetected_ResponseUnchanged(t *testing.T) {
	client := startPiiServer(t, &mockPiiSrv{})

	adapter := &stubAdapter{sendResp: &provider.ChatResponse{
		ID: "x",
		Choices: []provider.Choice{{
			Message: provider.Message{Role: "assistant", Content: "all good"},
		}},
	}}
	handler := proxy.New(adapter, zap.NewNop(), client)

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, postRequest(validBody))
	require.Equal(t, http.StatusOK, rec.Code)

	var decoded provider.ChatResponse
	require.NoError(t, json.NewDecoder(rec.Body).Decode(&decoded))
	assert.Equal(t, "all good", decoded.Choices[0].Message.Content)
}