341 lines
11 KiB
Go
341 lines
11 KiB
Go
package proxy_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap"
|
|
"google.golang.org/grpc"
|
|
|
|
"github.com/veylant/ia-gateway/internal/apierror"
|
|
"github.com/veylant/ia-gateway/internal/middleware"
|
|
piiv1 "github.com/veylant/ia-gateway/gen/pii/v1"
|
|
"github.com/veylant/ia-gateway/internal/pii"
|
|
"github.com/veylant/ia-gateway/internal/provider"
|
|
"github.com/veylant/ia-gateway/internal/proxy"
|
|
)
|
|
|
|
// ─── Stub adapter ─────────────────────────────────────────────────────────────
|
|
|
|
type stubAdapter struct {
|
|
sendResp *provider.ChatResponse
|
|
sendErr error
|
|
streamErr error
|
|
sseChunks []string // lines to emit during Stream
|
|
}
|
|
|
|
func (s *stubAdapter) Validate(req *provider.ChatRequest) error {
|
|
if req.Model == "" {
|
|
return apierror.NewBadRequestError("field 'model' is required")
|
|
}
|
|
if len(req.Messages) == 0 {
|
|
return apierror.NewBadRequestError("field 'messages' must not be empty")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *stubAdapter) Send(_ context.Context, _ *provider.ChatRequest) (*provider.ChatResponse, error) {
|
|
return s.sendResp, s.sendErr
|
|
}
|
|
|
|
func (s *stubAdapter) Stream(_ context.Context, _ *provider.ChatRequest, w http.ResponseWriter) error {
|
|
if s.streamErr != nil {
|
|
return s.streamErr
|
|
}
|
|
flusher, ok := w.(http.Flusher)
|
|
if !ok {
|
|
return fmt.Errorf("not flushable")
|
|
}
|
|
for _, line := range s.sseChunks {
|
|
fmt.Fprintf(w, "%s\n\n", line) //nolint:errcheck
|
|
flusher.Flush()
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// HealthCheck always reports the stub as healthy.
func (s *stubAdapter) HealthCheck(_ context.Context) error {
	return nil
}
|
|
|
|
// ─── Helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
func newHandler(adapter provider.Adapter) *proxy.Handler {
|
|
return proxy.New(adapter, zap.NewNop(), nil)
|
|
}
|
|
|
|
// postRequest builds a test POST request to /v1/chat/completions whose body
// is the given string.
func postRequest(body string) *http.Request {
	payload := strings.NewReader(body)
	return httptest.NewRequest(http.MethodPost, "/v1/chat/completions", payload)
}
|
|
|
|
// Canned request bodies shared by the tests in this file.
const (
	// validBody is a minimal chat completion request without the stream flag.
	validBody = `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}]}`
	// streamBody is the same request with "stream" set to true.
	streamBody = `{"model":"gpt-4o","messages":[{"role":"user","content":"hi"}],"stream":true}`
)
|
|
|
|
// ─── Tests ────────────────────────────────────────────────────────────────────
|
|
|
|
func TestHandler_NonStreaming_Nominal(t *testing.T) {
|
|
expected := &provider.ChatResponse{
|
|
ID: "chatcmpl-123",
|
|
Model: "gpt-4o",
|
|
Choices: []provider.Choice{{
|
|
Message: provider.Message{Role: "assistant", Content: "Hello!"},
|
|
}},
|
|
}
|
|
h := newHandler(&stubAdapter{sendResp: expected})
|
|
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, postRequest(validBody))
|
|
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "application/json", rec.Header().Get("Content-Type"))
|
|
|
|
var got provider.ChatResponse
|
|
require.NoError(t, json.NewDecoder(rec.Body).Decode(&got))
|
|
assert.Equal(t, expected.ID, got.ID)
|
|
assert.Equal(t, "Hello!", got.Choices[0].Message.Content)
|
|
}
|
|
|
|
func TestHandler_Streaming_Nominal(t *testing.T) {
|
|
chunks := []string{
|
|
`data: {"choices":[{"delta":{"content":"He"}}]}`,
|
|
`data: {"choices":[{"delta":{"content":"llo"}}]}`,
|
|
`data: [DONE]`,
|
|
}
|
|
h := newHandler(&stubAdapter{sseChunks: chunks})
|
|
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, postRequest(streamBody))
|
|
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "text/event-stream", rec.Header().Get("Content-Type"))
|
|
body := rec.Body.String()
|
|
assert.Contains(t, body, "He")
|
|
assert.Contains(t, body, "llo")
|
|
assert.Contains(t, body, "[DONE]")
|
|
}
|
|
|
|
func TestHandler_AdapterReturns429_ProxyReturns429(t *testing.T) {
|
|
h := newHandler(&stubAdapter{sendErr: apierror.NewRateLimitError("rate limit exceeded")})
|
|
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, postRequest(validBody))
|
|
|
|
assert.Equal(t, http.StatusTooManyRequests, rec.Code)
|
|
assertErrorType(t, rec, "rate_limit_error")
|
|
}
|
|
|
|
func TestHandler_AdapterTimeout_Returns504(t *testing.T) {
|
|
h := newHandler(&stubAdapter{sendErr: apierror.NewTimeoutError("timed out")})
|
|
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, postRequest(validBody))
|
|
|
|
assert.Equal(t, http.StatusGatewayTimeout, rec.Code)
|
|
assertErrorType(t, rec, "api_error")
|
|
}
|
|
|
|
func TestHandler_InvalidJSON_Returns400(t *testing.T) {
|
|
h := newHandler(&stubAdapter{})
|
|
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, postRequest(`not json`))
|
|
|
|
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
|
assertErrorType(t, rec, "invalid_request_error")
|
|
}
|
|
|
|
func TestHandler_MissingModel_Returns400(t *testing.T) {
|
|
body := `{"messages":[{"role":"user","content":"hi"}]}`
|
|
h := newHandler(&stubAdapter{})
|
|
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, postRequest(body))
|
|
|
|
assert.Equal(t, http.StatusBadRequest, rec.Code)
|
|
assertErrorType(t, rec, "invalid_request_error")
|
|
}
|
|
|
|
func TestHandler_RequestIDPropagated(t *testing.T) {
|
|
const fixedID = "01956a00-1111-7000-8000-000000000042"
|
|
h := newHandler(&stubAdapter{sendResp: &provider.ChatResponse{ID: "x"}})
|
|
|
|
req := postRequest(validBody)
|
|
ctx := middleware.RequestIDFromContext(req.Context()) // initially empty
|
|
_ = ctx
|
|
// Inject the request ID via context (as the RequestID middleware would do).
|
|
req = req.WithContext(context.Background())
|
|
// Simulate what middleware.RequestID would do.
|
|
req.Header.Set("X-Request-Id", fixedID)
|
|
// We use a simple wrapper to inject into context for the test.
|
|
injectCtx := injectRequestID(req, fixedID)
|
|
|
|
rec := httptest.NewRecorder()
|
|
rec.Header().Set("X-Request-Id", fixedID) // simulate middleware header
|
|
h.ServeHTTP(rec, injectCtx)
|
|
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
}
|
|
|
|
// assertErrorType checks the OpenAI error envelope type in the response body.
|
|
func assertErrorType(t *testing.T, rec *httptest.ResponseRecorder, expectedType string) {
|
|
t.Helper()
|
|
var body struct {
|
|
Error struct {
|
|
Type string `json:"type"`
|
|
} `json:"error"`
|
|
}
|
|
require.NoError(t, json.NewDecoder(rec.Body).Decode(&body))
|
|
assert.Equal(t, expectedType, body.Error.Type)
|
|
}
|
|
|
|
// injectRequestID is a placeholder for the middleware.RequestID middleware in
// tests. The middleware's context setter is unexported, so this helper cannot
// store the ID: it discards id and returns a shallow copy of r that carries
// r's existing context. In real usage the RequestID middleware handles this.
func injectRequestID(r *http.Request, id string) *http.Request {
	_ = id // unused: the context is passed through untouched
	return r.WithContext(r.Context())
}
|
|
|
|
// ─── Mock PII gRPC server ─────────────────────────────────────────────────
|
|
|
|
type mockPiiSrv struct {
|
|
piiv1.UnimplementedPiiServiceServer
|
|
resp *piiv1.PiiResponse
|
|
}
|
|
|
|
func (m *mockPiiSrv) Detect(_ context.Context, req *piiv1.PiiRequest) (*piiv1.PiiResponse, error) {
|
|
if m.resp != nil {
|
|
return m.resp, nil
|
|
}
|
|
return &piiv1.PiiResponse{AnonymizedText: req.Text}, nil
|
|
}
|
|
|
|
func startPiiServer(t *testing.T, srv piiv1.PiiServiceServer) *pii.Client {
|
|
t.Helper()
|
|
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
s := grpc.NewServer()
|
|
piiv1.RegisterPiiServiceServer(s, srv)
|
|
t.Cleanup(func() { s.Stop() })
|
|
go func() { _ = s.Serve(lis) }()
|
|
|
|
c, err := pii.New(pii.Config{
|
|
Address: lis.Addr().String(),
|
|
Timeout: 500 * time.Millisecond,
|
|
FailOpen: false,
|
|
}, zap.NewNop())
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() { _ = c.Close() })
|
|
return c
|
|
}
|
|
|
|
// ─── PII integration tests ─────────────────────────────────────────────────
|
|
|
|
func TestHandler_PII_Nil_ClientSkipsPII(t *testing.T) {
|
|
// With piiClient=nil, the handler should not modify the request.
|
|
resp := &provider.ChatResponse{
|
|
ID: "chatcmpl-pii",
|
|
Choices: []provider.Choice{{
|
|
Message: provider.Message{Role: "assistant", Content: "response text"},
|
|
}},
|
|
}
|
|
h := proxy.New(&stubAdapter{sendResp: resp}, zap.NewNop(), nil)
|
|
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, postRequest(validBody))
|
|
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
var got provider.ChatResponse
|
|
require.NoError(t, json.NewDecoder(rec.Body).Decode(&got))
|
|
assert.Equal(t, "response text", got.Choices[0].Message.Content)
|
|
}
|
|
|
|
func TestHandler_Depseudo_AppliedToResponse(t *testing.T) {
|
|
// The response contains a PII token — BuildEntityMap + Depseudonymize should restore it.
|
|
// We test this via depseudo_test.go directly; here we verify the handler wires them correctly
|
|
// when no piiClient is set (entityMap is nil/empty → response unchanged).
|
|
resp := &provider.ChatResponse{
|
|
ID: "x",
|
|
Choices: []provider.Choice{{
|
|
Message: provider.Message{Role: "assistant", Content: "[PII:EMAIL:abc12345]"},
|
|
}},
|
|
}
|
|
h := proxy.New(&stubAdapter{sendResp: resp}, zap.NewNop(), nil)
|
|
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, postRequest(validBody))
|
|
|
|
var got provider.ChatResponse
|
|
require.NoError(t, json.NewDecoder(rec.Body).Decode(&got))
|
|
// Without a PII client the token should be left unchanged.
|
|
assert.Equal(t, "[PII:EMAIL:abc12345]", got.Choices[0].Message.Content)
|
|
}
|
|
|
|
func TestHandler_PII_AnonymizesRequestAndDepseudonymizesResponse(t *testing.T) {
|
|
// The PII service returns an anonymized text and an entity mapping.
|
|
// The handler should: (1) replace last user message with anonymized text,
|
|
// (2) de-pseudonymize the LLM response.
|
|
piiSrv := &mockPiiSrv{
|
|
resp: &piiv1.PiiResponse{
|
|
AnonymizedText: "[PII:EMAIL:abc12345]",
|
|
Entities: []*piiv1.PiiEntity{
|
|
{
|
|
EntityType: "EMAIL",
|
|
OriginalValue: "alice@example.com",
|
|
Pseudonym: "[PII:EMAIL:abc12345]",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
piiClient := startPiiServer(t, piiSrv)
|
|
|
|
// The LLM "echoes" the anonymized token in the response.
|
|
llmResp := &provider.ChatResponse{
|
|
ID: "x",
|
|
Choices: []provider.Choice{{
|
|
Message: provider.Message{Role: "assistant", Content: "Your email is [PII:EMAIL:abc12345]"},
|
|
}},
|
|
}
|
|
h := proxy.New(&stubAdapter{sendResp: llmResp}, zap.NewNop(), piiClient)
|
|
|
|
body := `{"model":"gpt-4o","messages":[{"role":"user","content":"alice@example.com"}]}`
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, postRequest(body))
|
|
|
|
require.Equal(t, http.StatusOK, rec.Code)
|
|
var got provider.ChatResponse
|
|
require.NoError(t, json.NewDecoder(rec.Body).Decode(&got))
|
|
// The PII token in the response must be restored to the original value.
|
|
assert.Contains(t, got.Choices[0].Message.Content, "alice@example.com")
|
|
}
|
|
|
|
func TestHandler_PII_NoEntitiesDetected_ResponseUnchanged(t *testing.T) {
|
|
// PII service returns original text unchanged — response should also be unchanged.
|
|
piiClient := startPiiServer(t, &mockPiiSrv{})
|
|
|
|
llmResp := &provider.ChatResponse{
|
|
ID: "x",
|
|
Choices: []provider.Choice{{
|
|
Message: provider.Message{Role: "assistant", Content: "all good"},
|
|
}},
|
|
}
|
|
h := proxy.New(&stubAdapter{sendResp: llmResp}, zap.NewNop(), piiClient)
|
|
|
|
rec := httptest.NewRecorder()
|
|
h.ServeHTTP(rec, postRequest(validBody))
|
|
|
|
require.Equal(t, http.StatusOK, rec.Code)
|
|
var got provider.ChatResponse
|
|
require.NoError(t, json.NewDecoder(rec.Body).Decode(&got))
|
|
assert.Equal(t, "all good", got.Choices[0].Message.Content)
|
|
}
|