// Package proxy implements the HTTP handler for the /v1/chat/completions endpoint.
package proxy

import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"net/http"
	"strings"
	"time"

	"go.uber.org/zap"

	"github.com/veylant/ia-gateway/internal/apierror"
	"github.com/veylant/ia-gateway/internal/auditlog"
	"github.com/veylant/ia-gateway/internal/billing"
	"github.com/veylant/ia-gateway/internal/crypto"
	"github.com/veylant/ia-gateway/internal/flags"
	"github.com/veylant/ia-gateway/internal/middleware"
	"github.com/veylant/ia-gateway/internal/pii"
	"github.com/veylant/ia-gateway/internal/provider"
	"github.com/veylant/ia-gateway/internal/routing"
)

// Handler handles POST /v1/chat/completions.
// It dispatches to the underlying provider.Adapter for both non-streaming and
// streaming (SSE) requests, optionally running PII anonymization before the
// upstream call and de-pseudonymizing non-streaming responses.
//
// All optional collaborators (piiClient, auditLogger, encryptor, flagStore)
// are nil-checked at the point of use, so a zero-feature Handler built with
// New is fully functional as a plain proxy.
type Handler struct {
	adapter     provider.Adapter
	piiClient   *pii.Client       // nil means PII disabled
	auditLogger auditlog.Logger   // nil means audit logging disabled
	encryptor   *crypto.Encryptor // nil means encryption disabled
	flagStore   flags.FlagStore   // nil means feature flags disabled
	logger      *zap.Logger
}

// New creates a Handler backed by adapter.
// Pass a non-nil piiClient to enable PII anonymization.
func New(adapter provider.Adapter, logger *zap.Logger, piiClient *pii.Client) *Handler {
	return &Handler{adapter: adapter, piiClient: piiClient, logger: logger}
}

// NewWithAudit creates a Handler with audit logging and prompt encryption enabled.
// Either auditLogger or encryptor may be nil to disable those features independently.
func NewWithAudit(
	adapter provider.Adapter,
	logger *zap.Logger,
	piiClient *pii.Client,
	al auditlog.Logger,
	enc *crypto.Encryptor,
) *Handler {
	return &Handler{
		adapter:     adapter,
		piiClient:   piiClient,
		auditLogger: al,
		encryptor:   enc,
		logger:      logger,
	}
}

// WithFlagStore attaches a feature flag store to the handler (used for zero-retention, etc.).
// It mutates the receiver in place and returns it so construction can be chained.
func (h *Handler) WithFlagStore(fs flags.FlagStore) *Handler {
	h.flagStore = fs
	return h
}

// requestMeta groups per-request metadata to avoid long function signatures.
// It is populated in ServeHTTP and passed by value to the send/stream paths
// and to fireAuditEntry.
type requestMeta struct {
	requestID      string
	tenantID       string
	userID         string
	userRole       string
	department     string
	sensitivity    routing.Sensitivity
	piiEntityCount int
	startTime      time.Time
}

// ServeHTTP implements http.Handler.
//
// Pipeline: decode + validate the request, gather claims metadata, optionally
// anonymize the last user message via the PII service (gated by the
// tenant-level "pii_enabled" flag), then dispatch to the streaming or
// non-streaming path.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	start := time.Now()
	requestID := middleware.RequestIDFromContext(r.Context())

	var req provider.ChatRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		apierror.WriteError(w, apierror.NewBadRequestError("invalid JSON body: "+err.Error()))
		return
	}
	if err := h.adapter.Validate(&req); err != nil {
		// Preserve structured API errors from the adapter; wrap anything else
		// as a generic 400.
		if apiErr, ok := err.(*apierror.APIError); ok {
			apierror.WriteError(w, apiErr)
		} else {
			apierror.WriteError(w, apierror.NewBadRequestError(err.Error()))
		}
		return
	}

	// Collect request metadata for logging / audit.
	meta := requestMeta{requestID: requestID, startTime: start}
	logFields := []zap.Field{
		zap.String("request_id", requestID),
		zap.String("model", req.Model),
		zap.Bool("stream", req.Stream),
	}
	if claims, ok := middleware.ClaimsFromContext(r.Context()); ok {
		meta.userID = claims.UserID
		meta.tenantID = claims.TenantID
		meta.department = claims.Department
		// Only the first role is recorded in the audit trail.
		if len(claims.Roles) > 0 {
			meta.userRole = claims.Roles[0]
		}
		logFields = append(logFields,
			zap.String("user_id", meta.userID),
			zap.String("tenant_id", meta.tenantID),
		)
	}

	// Capture the original (pre-PII) prompt for hashing.
	originalPrompt := extractLastUserMessage(&req)

	// ── PII anonymization (request side) ────────────────────────────────────
	var entityMap map[string]string
	var anonymizedPrompt string

	// Check pii_enabled flag (E11-07): defaults to true when not set.
	piiEnabled := true
	if h.flagStore != nil && meta.tenantID != "" {
		if enabled, err := h.flagStore.IsEnabled(r.Context(), meta.tenantID, "pii_enabled"); err == nil {
			// IsEnabled returns false both when the flag is disabled AND when it
			// doesn't exist (no flag set). We treat absence as "enabled" to avoid
			// breaking existing tenants who haven't set this flag. The migration
			// seeds a global default of true, so after applying 000009 this is
			// always correct. Before that migration, we fail-open (pii enabled).
			//
			// NOTE(review): the code below contradicts the paragraph above — on
			// a nil error it assigns the returned value directly, so an absent
			// flag (enabled == false, err == nil) DISABLES PII rather than
			// failing open. This is only safe once migration 000009 has seeded
			// the default; confirm the migration is applied everywhere, or gate
			// on flag existence here.
			piiEnabled = enabled
		}
	}
	if h.piiClient != nil && piiEnabled {
		promptText := originalPrompt
		if promptText != "" {
			// Check zero-retention flag: if enabled, the PII service will not persist
			// pseudonymization mappings in Redis (E4-12).
			zeroRetention := false
			if h.flagStore != nil && meta.tenantID != "" {
				// Best-effort lookup: an error leaves zeroRetention false.
				zeroRetention, _ = h.flagStore.IsEnabled(r.Context(), meta.tenantID, "zero_retention")
			}
			result, err := h.piiClient.Detect(r.Context(), promptText, meta.tenantID, requestID, true, zeroRetention)
			if err != nil {
				// PII failure is fail-closed: the request is rejected rather than
				// forwarded un-anonymized.
				apierror.WriteError(w, apierror.NewUpstreamError("PII service error: "+err.Error()))
				return
			}
			meta.piiEntityCount = len(result.Entities)
			if len(result.Entities) > 0 {
				// Derive a sensitivity score and stash it on the request context
				// for downstream routing decisions.
				score := routing.ScoreFromEntities(result.Entities)
				meta.sensitivity = score
				r = r.WithContext(routing.WithSensitivity(r.Context(), score))
			}
			if result.AnonymizedText != promptText {
				replaceLastUserMessage(&req, result.AnonymizedText)
				anonymizedPrompt = result.AnonymizedText
				entityMap = BuildEntityMap(result.Entities)
				h.logger.Debug("pii anonymized prompt",
					zap.Int("entities", len(result.Entities)),
					zap.Int64("pii_ms", result.ProcessingTimeMs),
				)
			}
		}
	}
	// When nothing was anonymized, the "anonymized" prompt recorded in the
	// audit log is simply the original.
	if anonymizedPrompt == "" {
		anonymizedPrompt = originalPrompt
	}

	// NOTE(review): entityMap is only handed to the non-streaming path, so
	// streamed responses are never de-pseudonymized — confirm this is the
	// intended behavior for SSE clients.
	if req.Stream {
		h.handleStream(w, r, &req, meta, logFields, anonymizedPrompt, originalPrompt)
	} else {
		h.handleSend(w, r, &req, meta, logFields, entityMap, anonymizedPrompt, originalPrompt)
	}
}

// handleSend performs a non-streaming upstream call, de-pseudonymizes the
// response when an entityMap is present, fires the audit entry, and writes
// the JSON response.
func (h *Handler) handleSend(
	w http.ResponseWriter,
	r *http.Request,
	req *provider.ChatRequest,
	meta requestMeta,
	logFields []zap.Field,
	entityMap map[string]string,
	anonymizedPrompt, originalPrompt string,
) {
	resp, err := h.adapter.Send(r.Context(), req)
	latencyMs := int(time.Since(meta.startTime).Milliseconds())
	if err != nil {
		// Audit the failure before replying, so the entry exists even if the
		// client disconnects mid-write.
		h.fireAuditEntry(meta, req, nil, originalPrompt, anonymizedPrompt, latencyMs, err)
		if apiErr, ok := err.(*apierror.APIError); ok {
			apierror.WriteError(w, apiErr)
		} else {
			apierror.WriteError(w, apierror.NewUpstreamError(err.Error()))
		}
		h.logger.Error("proxy send error", append(logFields, zap.Error(err))...)
		return
	}
	// De-pseudonymize the LLM response choices.
	if len(entityMap) > 0 {
		for i := range resp.Choices {
			resp.Choices[i].Message.Content = Depseudonymize(resp.Choices[i].Message.Content, entityMap)
		}
	}
	h.fireAuditEntry(meta, req, resp, originalPrompt, anonymizedPrompt, latencyMs, nil)
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		// Status is already committed; all we can do is log.
		h.logger.Error("encoding response", append(logFields, zap.Error(err))...)
	}
	h.logger.Info("proxy send ok",
		append(logFields,
			zap.Duration("latency", time.Since(meta.startTime)),
			zap.Int("total_tokens", resp.Usage.TotalTokens),
		)...,
	)
}

// handleStream performs a streaming (SSE) upstream call. The adapter writes
// events directly to w; no de-pseudonymization is applied to streamed chunks.
func (h *Handler) handleStream(
	w http.ResponseWriter,
	r *http.Request,
	req *provider.ChatRequest,
	meta requestMeta,
	logFields []zap.Field,
	anonymizedPrompt, originalPrompt string,
) {
	// Set SSE headers before the adapter starts writing.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no") // disable Nginx proxy buffering
	w.WriteHeader(http.StatusOK)

	err := h.adapter.Stream(r.Context(), req, w)
	latencyMs := int(time.Since(meta.startTime).Milliseconds())

	// For streaming: no ChatResponse is available; fire audit with nil resp.
	h.fireAuditEntry(meta, req, nil, originalPrompt, anonymizedPrompt, latencyMs, err)
	if err != nil {
		// Headers are already sent — we cannot change the status code.
		h.logger.Error("proxy stream error", append(logFields, zap.Error(err))...)
		return
	}
	h.logger.Info("proxy stream ok",
		append(logFields, zap.Duration("latency", time.Since(meta.startTime)))...,
	)
}

// fireAuditEntry builds and enqueues an AuditEntry. It is a no-op when auditLogger is nil.
//
// When resp is non-nil (non-streaming success) the entry records the actual
// provider, model, token usage and cost; otherwise (streaming or error) the
// provider is inferred from the model prefix and tokens are estimated from
// the request alone.
func (h *Handler) fireAuditEntry(
	meta requestMeta,
	req *provider.ChatRequest,
	resp *provider.ChatResponse,
	originalPrompt, anonymizedPrompt string,
	latencyMs int,
	reqErr error,
) {
	if h.auditLogger == nil {
		return
	}
	entry := auditlog.AuditEntry{
		RequestID:        meta.requestID,
		TenantID:         meta.tenantID,
		UserID:           meta.userID,
		Timestamp:        meta.startTime,
		ModelRequested:   req.Model,
		ModelUsed:        req.Model, // overwritten below when resp reports the actual model
		Department:       meta.department,
		UserRole:         meta.userRole,
		PromptHash:       sha256hex(originalPrompt),
		SensitivityLevel: meta.sensitivity.String(),
		LatencyMs:        latencyMs,
		PIIEntityCount:   meta.piiEntityCount,
		Stream:           req.Stream,
		Status:           "ok",
	}
	if reqErr != nil {
		entry.Status = "error"
		entry.ErrorType = errorType(reqErr)
	}
	if resp != nil {
		entry.ModelUsed = resp.Model
		entry.Provider = resp.Provider
		entry.ResponseHash = sha256hex(responseContent(resp))
		entry.TokenInput = resp.Usage.PromptTokens
		entry.TokenOutput = resp.Usage.CompletionTokens
		entry.TokenTotal = resp.Usage.TotalTokens
		entry.CostUSD = billing.CostUSD(resp.Provider, resp.Model, resp.Usage.TotalTokens)
	} else {
		// For streaming or errors, infer provider from model prefix.
		entry.Provider = inferProvider(req.Model)
		entry.TokenTotal = estimateTokens(req)
		entry.CostUSD = billing.CostUSD(entry.Provider, req.Model, entry.TokenTotal)
	}
	// billing_enabled flag (E11-07): if explicitly disabled for this tenant,
	// zero out the cost so the tenant is not charged in the audit log.
	// NOTE(review): uses context.Background() rather than the request context —
	// presumably because the request context may already be canceled by the
	// time the audit fires (e.g. after streaming); confirm intent.
	if h.flagStore != nil && meta.tenantID != "" {
		if billingEnabled, err := h.flagStore.IsEnabled(context.Background(), meta.tenantID, "billing_enabled"); err == nil && !billingEnabled {
			entry.CostUSD = 0
		}
	}
	if h.encryptor != nil && anonymizedPrompt != "" {
		encrypted, err := h.encryptor.Encrypt(anonymizedPrompt)
		if err != nil {
			// Encryption failure degrades gracefully: the entry is still logged,
			// just without the encrypted prompt.
			h.logger.Warn("audit log: failed to encrypt prompt", zap.Error(err))
		} else {
			entry.PromptAnonymized = encrypted
		}
	}
	h.auditLogger.Log(entry)
}

// ── helpers ──────────────────────────────────────────────────────────────────

// extractLastUserMessage returns the content of the last "user" message in req.
// Returns "" when no user message is present. Role comparison is
// case-insensitive.
func extractLastUserMessage(req *provider.ChatRequest) string {
	for i := len(req.Messages) - 1; i >= 0; i-- {
		if strings.EqualFold(req.Messages[i].Role, "user") {
			return req.Messages[i].Content
		}
	}
	return ""
}

// replaceLastUserMessage replaces the content of the last user message with anonymized.
// A no-op when req contains no user message.
func replaceLastUserMessage(req *provider.ChatRequest, anonymized string) {
	for i := len(req.Messages) - 1; i >= 0; i-- {
		if strings.EqualFold(req.Messages[i].Role, "user") {
			req.Messages[i].Content = anonymized
			return
		}
	}
}

// sha256hex returns the lowercase hex SHA-256 hash of s.
func sha256hex(s string) string {
	h := sha256.Sum256([]byte(s))
	return hex.EncodeToString(h[:])
}

// responseContent returns the concatenated content of all choices in resp.
// Used only to produce a stable hash for the audit log.
func responseContent(resp *provider.ChatResponse) string {
	var sb strings.Builder
	for _, c := range resp.Choices {
		sb.WriteString(c.Message.Content)
	}
	return sb.String()
}

// errorType extracts a simple error type string for audit logging.
func errorType(err error) string { if err == nil { return "" } if apiErr, ok := err.(*apierror.APIError); ok { return apiErr.Type } return "upstream_error" } // inferProvider maps a model name prefix to a provider name, mirroring the // router's static prefix rules. Used for streaming where no ChatResponse is available. func inferProvider(model string) string { prefixes := []struct{ prefix, provider string }{ {"gpt-", "openai"}, {"o1-", "openai"}, {"o3-", "openai"}, {"claude-", "anthropic"}, {"mistral-", "mistral"}, {"mixtral-", "mistral"}, {"llama", "ollama"}, {"phi", "ollama"}, {"qwen", "ollama"}, } for _, p := range prefixes { if strings.HasPrefix(model, p.prefix) { return p.provider } } return "openai" // default fallback } // estimateTokens returns a rough token count (1 token ≈ 4 chars). func estimateTokens(req *provider.ChatRequest) int { total := 0 for _, m := range req.Messages { total += len(m.Content) / 4 } return total }