package proxy_test import ( "context" "encoding/json" "fmt" "net" "net/http" "net/http/httptest" "strings" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap" "google.golang.org/grpc" "github.com/veylant/ia-gateway/internal/apierror" "github.com/veylant/ia-gateway/internal/middleware" piiv1 "github.com/veylant/ia-gateway/gen/pii/v1" "github.com/veylant/ia-gateway/internal/pii" "github.com/veylant/ia-gateway/internal/provider" "github.com/veylant/ia-gateway/internal/proxy" ) // ─── Stub adapter ───────────────────────────────────────────────────────────── type stubAdapter struct { sendResp *provider.ChatResponse sendErr error streamErr error sseChunks []string // lines to emit during Stream } func (s *stubAdapter) Validate(req *provider.ChatRequest) error { if req.Model == "" { return apierror.NewBadRequestError("field 'model' is required") } if len(req.Messages) == 0 { return apierror.NewBadRequestError("field 'messages' must not be empty") } return nil } func (s *stubAdapter) Send(_ context.Context, _ *provider.ChatRequest) (*provider.ChatResponse, error) { return s.sendResp, s.sendErr } func (s *stubAdapter) Stream(_ context.Context, _ *provider.ChatRequest, w http.ResponseWriter) error { if s.streamErr != nil { return s.streamErr } flusher, ok := w.(http.Flusher) if !ok { return fmt.Errorf("not flushable") } for _, line := range s.sseChunks { fmt.Fprintf(w, "%s\n\n", line) //nolint:errcheck flusher.Flush() } return nil } func (s *stubAdapter) HealthCheck(_ context.Context) error { return nil } // ─── Helpers ────────────────────────────────────────────────────────────────── func newHandler(adapter provider.Adapter) *proxy.Handler { return proxy.New(adapter, zap.NewNop(), nil) } func postRequest(body string) *http.Request { return httptest.NewRequest(http.MethodPost, "/v1/chat/completions", strings.NewReader(body)) } const validBody = `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}` const streamBody = `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}],"stream":true}` // ─── Tests ──────────────────────────────────────────────────────────────────── func TestHandler_NonStreaming_Nominal(t *testing.T) { expected := &provider.ChatResponse{ ID: "chatcmpl-123", Model: "gpt-4o", Choices: []provider.Choice{{ Message: provider.Message{Role: "assistant", Content: "Hello!"}, }}, } h := newHandler(&stubAdapter{sendResp: expected}) rec := httptest.NewRecorder() h.ServeHTTP(rec, postRequest(validBody)) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "application/json", rec.Header().Get("Content-Type")) var got provider.ChatResponse require.NoError(t, json.NewDecoder(rec.Body).Decode(&got)) assert.Equal(t, expected.ID, got.ID) assert.Equal(t, "Hello!", got.Choices[0].Message.Content) } func TestHandler_Streaming_Nominal(t *testing.T) { chunks := []string{ `data: {"choices":[{"delta":{"content":"He"}}]}`, `data: {"choices":[{"delta":{"content":"llo"}}]}`, `data: [DONE]`, } h := newHandler(&stubAdapter{sseChunks: chunks}) rec := httptest.NewRecorder() h.ServeHTTP(rec, postRequest(streamBody)) assert.Equal(t, http.StatusOK, rec.Code) assert.Equal(t, "text/event-stream", rec.Header().Get("Content-Type")) body := rec.Body.String() assert.Contains(t, body, "He") assert.Contains(t, body, "llo") assert.Contains(t, body, "[DONE]") } func TestHandler_AdapterReturns429_ProxyReturns429(t *testing.T) { h := newHandler(&stubAdapter{sendErr: apierror.NewRateLimitError("rate limit exceeded")}) rec := httptest.NewRecorder() h.ServeHTTP(rec, postRequest(validBody)) assert.Equal(t, http.StatusTooManyRequests, rec.Code) assertErrorType(t, rec, "rate_limit_error") } func TestHandler_AdapterTimeout_Returns504(t *testing.T) { h := newHandler(&stubAdapter{sendErr: apierror.NewTimeoutError("timed out")}) rec := httptest.NewRecorder() h.ServeHTTP(rec, postRequest(validBody)) assert.Equal(t, http.StatusGatewayTimeout, rec.Code) assertErrorType(t, rec, "api_error") } func TestHandler_InvalidJSON_Returns400(t *testing.T) { h := newHandler(&stubAdapter{}) rec := httptest.NewRecorder() h.ServeHTTP(rec, postRequest(`not json`)) assert.Equal(t, http.StatusBadRequest, rec.Code) assertErrorType(t, rec, "invalid_request_error") } func TestHandler_MissingModel_Returns400(t *testing.T) { body := `{"messages":[{"role":"user","content":"hi"}]}` h := newHandler(&stubAdapter{}) rec := httptest.NewRecorder() h.ServeHTTP(rec, postRequest(body)) assert.Equal(t, http.StatusBadRequest, rec.Code) assertErrorType(t, rec, "invalid_request_error") } func TestHandler_RequestIDPropagated(t *testing.T) { const fixedID = "01956a00-1111-7000-8000-000000000042" h := newHandler(&stubAdapter{sendResp: &provider.ChatResponse{ID: "x"}}) req := postRequest(validBody) ctx := middleware.RequestIDFromContext(req.Context()) // initially empty _ = ctx // Inject the request ID via context (as the RequestID middleware would do). req = req.WithContext(context.Background()) // Simulate what middleware.RequestID would do. req.Header.Set("X-Request-Id", fixedID) // We use a simple wrapper to inject into context for the test. injectCtx := injectRequestID(req, fixedID) rec := httptest.NewRecorder() rec.Header().Set("X-Request-Id", fixedID) // simulate middleware header h.ServeHTTP(rec, injectCtx) assert.Equal(t, http.StatusOK, rec.Code) } // assertErrorType checks the OpenAI error envelope type in the response body. func assertErrorType(t *testing.T, rec *httptest.ResponseRecorder, expectedType string) { t.Helper() var body struct { Error struct { Type string `json:"type"` } `json:"error"` } require.NoError(t, json.NewDecoder(rec.Body).Decode(&body)) assert.Equal(t, expectedType, body.Error.Type) } // injectRequestID injects a request ID into the request context, mimicking // the middleware.RequestID middleware for test purposes. func injectRequestID(r *http.Request, id string) *http.Request { // We access the internal function via the exported accessor pattern. // Since withRequestID is unexported, we rebuild the request via context. // In real usage the RequestID middleware handles this. ctx := r.Context() _ = id return r.WithContext(ctx) } // ─── Mock PII gRPC server ───────────────────────────────────────────────── type mockPiiSrv struct { piiv1.UnimplementedPiiServiceServer resp *piiv1.PiiResponse } func (m *mockPiiSrv) Detect(_ context.Context, req *piiv1.PiiRequest) (*piiv1.PiiResponse, error) { if m.resp != nil { return m.resp, nil } return &piiv1.PiiResponse{AnonymizedText: req.Text}, nil } func startPiiServer(t *testing.T, srv piiv1.PiiServiceServer) *pii.Client { t.Helper() lis, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) s := grpc.NewServer() piiv1.RegisterPiiServiceServer(s, srv) t.Cleanup(func() { s.Stop() }) go func() { _ = s.Serve(lis) }() c, err := pii.New(pii.Config{ Address: lis.Addr().String(), Timeout: 500 * time.Millisecond, FailOpen: false, }, zap.NewNop()) require.NoError(t, err) t.Cleanup(func() { _ = c.Close() }) return c } // ─── PII integration tests ───────────────────────────────────────────────── func TestHandler_PII_Nil_ClientSkipsPII(t *testing.T) { // With piiClient=nil, the handler should not modify the request. resp := &provider.ChatResponse{ ID: "chatcmpl-pii", Choices: []provider.Choice{{ Message: provider.Message{Role: "assistant", Content: "response text"}, }}, } h := proxy.New(&stubAdapter{sendResp: resp}, zap.NewNop(), nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, postRequest(validBody)) assert.Equal(t, http.StatusOK, rec.Code) var got provider.ChatResponse require.NoError(t, json.NewDecoder(rec.Body).Decode(&got)) assert.Equal(t, "response text", got.Choices[0].Message.Content) } func TestHandler_Depseudo_AppliedToResponse(t *testing.T) { // The response contains a PII token — BuildEntityMap + Depseudonymize should restore it. // We test this via depseudo_test.go directly; here we verify the handler wires them correctly // when no piiClient is set (entityMap is nil/empty → response unchanged). resp := &provider.ChatResponse{ ID: "x", Choices: []provider.Choice{{ Message: provider.Message{Role: "assistant", Content: "[PII:EMAIL:abc12345]"}, }}, } h := proxy.New(&stubAdapter{sendResp: resp}, zap.NewNop(), nil) rec := httptest.NewRecorder() h.ServeHTTP(rec, postRequest(validBody)) var got provider.ChatResponse require.NoError(t, json.NewDecoder(rec.Body).Decode(&got)) // Without a PII client the token should be left unchanged. assert.Equal(t, "[PII:EMAIL:abc12345]", got.Choices[0].Message.Content) } func TestHandler_PII_AnonymizesRequestAndDepseudonymizesResponse(t *testing.T) { // The PII service returns an anonymized text and an entity mapping. // The handler should: (1) replace last user message with anonymized text, // (2) de-pseudonymize the LLM response. piiSrv := &mockPiiSrv{ resp: &piiv1.PiiResponse{ AnonymizedText: "[PII:EMAIL:abc12345]", Entities: []*piiv1.PiiEntity{ { EntityType: "EMAIL", OriginalValue: "alice@example.com", Pseudonym: "[PII:EMAIL:abc12345]", }, }, }, } piiClient := startPiiServer(t, piiSrv) // The LLM "echoes" the anonymized token in the response. llmResp := &provider.ChatResponse{ ID: "x", Choices: []provider.Choice{{ Message: provider.Message{Role: "assistant", Content: "Your email is [PII:EMAIL:abc12345]"}, }}, } h := proxy.New(&stubAdapter{sendResp: llmResp}, zap.NewNop(), piiClient) body := `{"model":"gpt-4o","messages":[{"role":"user","content":"alice@example.com"}]}` rec := httptest.NewRecorder() h.ServeHTTP(rec, postRequest(body)) require.Equal(t, http.StatusOK, rec.Code) var got provider.ChatResponse require.NoError(t, json.NewDecoder(rec.Body).Decode(&got)) // The PII token in the response must be restored to the original value. assert.Contains(t, got.Choices[0].Message.Content, "alice@example.com") } func TestHandler_PII_NoEntitiesDetected_ResponseUnchanged(t *testing.T) { // PII service returns original text unchanged — response should also be unchanged. piiClient := startPiiServer(t, &mockPiiSrv{}) llmResp := &provider.ChatResponse{ ID: "x", Choices: []provider.Choice{{ Message: provider.Message{Role: "assistant", Content: "all good"}, }}, } h := proxy.New(&stubAdapter{sendResp: llmResp}, zap.NewNop(), piiClient) rec := httptest.NewRecorder() h.ServeHTTP(rec, postRequest(validBody)) require.Equal(t, http.StatusOK, rec.Code) var got provider.ChatResponse require.NoError(t, json.NewDecoder(rec.Body).Decode(&got)) assert.Equal(t, "all good", got.Choices[0].Message.Content) }