veylant/internal/proxy/handler.go
2026-02-23 13:35:04 +01:00

411 lines
13 KiB
Go

// Package proxy implements the HTTP handler for the /v1/chat/completions endpoint.
package proxy
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"net/http"
"strings"
"time"
"go.uber.org/zap"
"github.com/veylant/ia-gateway/internal/apierror"
"github.com/veylant/ia-gateway/internal/auditlog"
"github.com/veylant/ia-gateway/internal/billing"
"github.com/veylant/ia-gateway/internal/crypto"
"github.com/veylant/ia-gateway/internal/flags"
"github.com/veylant/ia-gateway/internal/middleware"
"github.com/veylant/ia-gateway/internal/pii"
"github.com/veylant/ia-gateway/internal/provider"
"github.com/veylant/ia-gateway/internal/routing"
)
// Handler handles POST /v1/chat/completions.
// It dispatches to the underlying provider.Adapter for both non-streaming and
// streaming (SSE) requests, optionally running PII anonymization before the
// upstream call and de-pseudonymizing non-streaming responses.
type Handler struct {
adapter provider.Adapter
piiClient *pii.Client // nil means PII disabled
auditLogger auditlog.Logger // nil means audit logging disabled
encryptor *crypto.Encryptor // nil means encryption disabled
flagStore flags.FlagStore // nil means feature flags disabled
logger *zap.Logger
}
// New creates a Handler backed by adapter.
// Pass a non-nil piiClient to enable PII anonymization.
func New(adapter provider.Adapter, logger *zap.Logger, piiClient *pii.Client) *Handler {
return &Handler{adapter: adapter, piiClient: piiClient, logger: logger}
}
// NewWithAudit creates a Handler with audit logging and prompt encryption enabled.
// Either auditLogger or encryptor may be nil to disable those features independently.
func NewWithAudit(
adapter provider.Adapter,
logger *zap.Logger,
piiClient *pii.Client,
al auditlog.Logger,
enc *crypto.Encryptor,
) *Handler {
return &Handler{
adapter: adapter,
piiClient: piiClient,
auditLogger: al,
encryptor: enc,
logger: logger,
}
}
// WithFlagStore attaches a feature flag store to the handler (used for zero-retention, etc.).
func (h *Handler) WithFlagStore(fs flags.FlagStore) *Handler {
h.flagStore = fs
return h
}
// requestMeta groups per-request metadata to avoid long function signatures.
type requestMeta struct {
requestID string
tenantID string
userID string
userRole string
department string
sensitivity routing.Sensitivity
piiEntityCount int
startTime time.Time
}
// ServeHTTP implements http.Handler.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
start := time.Now()
requestID := middleware.RequestIDFromContext(r.Context())
var req provider.ChatRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
apierror.WriteError(w, apierror.NewBadRequestError("invalid JSON body: "+err.Error()))
return
}
if err := h.adapter.Validate(&req); err != nil {
if apiErr, ok := err.(*apierror.APIError); ok {
apierror.WriteError(w, apiErr)
} else {
apierror.WriteError(w, apierror.NewBadRequestError(err.Error()))
}
return
}
// Collect request metadata for logging / audit.
meta := requestMeta{requestID: requestID, startTime: start}
logFields := []zap.Field{
zap.String("request_id", requestID),
zap.String("model", req.Model),
zap.Bool("stream", req.Stream),
}
if claims, ok := middleware.ClaimsFromContext(r.Context()); ok {
meta.userID = claims.UserID
meta.tenantID = claims.TenantID
meta.department = claims.Department
if len(claims.Roles) > 0 {
meta.userRole = claims.Roles[0]
}
logFields = append(logFields,
zap.String("user_id", meta.userID),
zap.String("tenant_id", meta.tenantID),
)
}
// Capture the original (pre-PII) prompt for hashing.
originalPrompt := extractLastUserMessage(&req)
// ── PII anonymization (request side) ────────────────────────────────────
var entityMap map[string]string
var anonymizedPrompt string
// Check pii_enabled flag (E11-07): defaults to true when not set.
piiEnabled := true
if h.flagStore != nil && meta.tenantID != "" {
if enabled, err := h.flagStore.IsEnabled(r.Context(), meta.tenantID, "pii_enabled"); err == nil {
// IsEnabled returns false both when the flag is disabled AND when it
// doesn't exist (no flag set). We treat absence as "enabled" to avoid
// breaking existing tenants who haven't set this flag. The migration
// seeds a global default of true, so after applying 000009 this is
// always correct. Before that migration, we fail-open (pii enabled).
piiEnabled = enabled
}
}
if h.piiClient != nil && piiEnabled {
promptText := originalPrompt
if promptText != "" {
// Check zero-retention flag: if enabled, the PII service will not persist
// pseudonymization mappings in Redis (E4-12).
zeroRetention := false
if h.flagStore != nil && meta.tenantID != "" {
zeroRetention, _ = h.flagStore.IsEnabled(r.Context(), meta.tenantID, "zero_retention")
}
result, err := h.piiClient.Detect(r.Context(), promptText, meta.tenantID, requestID, true, zeroRetention)
if err != nil {
apierror.WriteError(w, apierror.NewUpstreamError("PII service error: "+err.Error()))
return
}
meta.piiEntityCount = len(result.Entities)
if len(result.Entities) > 0 {
score := routing.ScoreFromEntities(result.Entities)
meta.sensitivity = score
r = r.WithContext(routing.WithSensitivity(r.Context(), score))
}
if result.AnonymizedText != promptText {
replaceLastUserMessage(&req, result.AnonymizedText)
anonymizedPrompt = result.AnonymizedText
entityMap = BuildEntityMap(result.Entities)
h.logger.Debug("pii anonymized prompt",
zap.Int("entities", len(result.Entities)),
zap.Int64("pii_ms", result.ProcessingTimeMs),
)
}
}
}
if anonymizedPrompt == "" {
anonymizedPrompt = originalPrompt
}
if req.Stream {
h.handleStream(w, r, &req, meta, logFields, anonymizedPrompt, originalPrompt)
} else {
h.handleSend(w, r, &req, meta, logFields, entityMap, anonymizedPrompt, originalPrompt)
}
}
func (h *Handler) handleSend(
w http.ResponseWriter,
r *http.Request,
req *provider.ChatRequest,
meta requestMeta,
logFields []zap.Field,
entityMap map[string]string,
anonymizedPrompt, originalPrompt string,
) {
resp, err := h.adapter.Send(r.Context(), req)
latencyMs := int(time.Since(meta.startTime).Milliseconds())
if err != nil {
h.fireAuditEntry(meta, req, nil, originalPrompt, anonymizedPrompt, latencyMs, err)
if apiErr, ok := err.(*apierror.APIError); ok {
apierror.WriteError(w, apiErr)
} else {
apierror.WriteError(w, apierror.NewUpstreamError(err.Error()))
}
h.logger.Error("proxy send error", append(logFields, zap.Error(err))...)
return
}
// De-pseudonymize the LLM response choices.
if len(entityMap) > 0 {
for i := range resp.Choices {
resp.Choices[i].Message.Content = Depseudonymize(resp.Choices[i].Message.Content, entityMap)
}
}
h.fireAuditEntry(meta, req, resp, originalPrompt, anonymizedPrompt, latencyMs, nil)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if err := json.NewEncoder(w).Encode(resp); err != nil {
h.logger.Error("encoding response", append(logFields, zap.Error(err))...)
}
h.logger.Info("proxy send ok",
append(logFields,
zap.Duration("latency", time.Since(meta.startTime)),
zap.Int("total_tokens", resp.Usage.TotalTokens),
)...,
)
}
func (h *Handler) handleStream(
w http.ResponseWriter,
r *http.Request,
req *provider.ChatRequest,
meta requestMeta,
logFields []zap.Field,
anonymizedPrompt, originalPrompt string,
) {
// Set SSE headers before the adapter starts writing.
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no") // disable Nginx proxy buffering
w.WriteHeader(http.StatusOK)
err := h.adapter.Stream(r.Context(), req, w)
latencyMs := int(time.Since(meta.startTime).Milliseconds())
// For streaming: no ChatResponse is available; fire audit with nil resp.
h.fireAuditEntry(meta, req, nil, originalPrompt, anonymizedPrompt, latencyMs, err)
if err != nil {
// Headers are already sent — we cannot change the status code.
h.logger.Error("proxy stream error", append(logFields, zap.Error(err))...)
return
}
h.logger.Info("proxy stream ok",
append(logFields, zap.Duration("latency", time.Since(meta.startTime)))...,
)
}
// fireAuditEntry builds and enqueues an AuditEntry. It is a no-op when auditLogger is nil.
func (h *Handler) fireAuditEntry(
meta requestMeta,
req *provider.ChatRequest,
resp *provider.ChatResponse,
originalPrompt, anonymizedPrompt string,
latencyMs int,
reqErr error,
) {
if h.auditLogger == nil {
return
}
entry := auditlog.AuditEntry{
RequestID: meta.requestID,
TenantID: meta.tenantID,
UserID: meta.userID,
Timestamp: meta.startTime,
ModelRequested: req.Model,
ModelUsed: req.Model,
Department: meta.department,
UserRole: meta.userRole,
PromptHash: sha256hex(originalPrompt),
SensitivityLevel: meta.sensitivity.String(),
LatencyMs: latencyMs,
PIIEntityCount: meta.piiEntityCount,
Stream: req.Stream,
Status: "ok",
}
if reqErr != nil {
entry.Status = "error"
entry.ErrorType = errorType(reqErr)
}
if resp != nil {
entry.ModelUsed = resp.Model
entry.Provider = resp.Provider
entry.ResponseHash = sha256hex(responseContent(resp))
entry.TokenInput = resp.Usage.PromptTokens
entry.TokenOutput = resp.Usage.CompletionTokens
entry.TokenTotal = resp.Usage.TotalTokens
entry.CostUSD = billing.CostUSD(resp.Provider, resp.Model, resp.Usage.TotalTokens)
} else {
// For streaming or errors, infer provider from model prefix.
entry.Provider = inferProvider(req.Model)
entry.TokenTotal = estimateTokens(req)
entry.CostUSD = billing.CostUSD(entry.Provider, req.Model, entry.TokenTotal)
}
// billing_enabled flag (E11-07): if explicitly disabled for this tenant,
// zero out the cost so the tenant is not charged in the audit log.
if h.flagStore != nil && meta.tenantID != "" {
if billingEnabled, err := h.flagStore.IsEnabled(context.Background(), meta.tenantID, "billing_enabled"); err == nil && !billingEnabled {
entry.CostUSD = 0
}
}
if h.encryptor != nil && anonymizedPrompt != "" {
encrypted, err := h.encryptor.Encrypt(anonymizedPrompt)
if err != nil {
h.logger.Warn("audit log: failed to encrypt prompt", zap.Error(err))
} else {
entry.PromptAnonymized = encrypted
}
}
h.auditLogger.Log(entry)
}
// ── helpers ──────────────────────────────────────────────────────────────────
// extractLastUserMessage returns the content of the last "user" message in req.
func extractLastUserMessage(req *provider.ChatRequest) string {
for i := len(req.Messages) - 1; i >= 0; i-- {
if strings.EqualFold(req.Messages[i].Role, "user") {
return req.Messages[i].Content
}
}
return ""
}
// replaceLastUserMessage replaces the content of the last user message with anonymized.
func replaceLastUserMessage(req *provider.ChatRequest, anonymized string) {
for i := len(req.Messages) - 1; i >= 0; i-- {
if strings.EqualFold(req.Messages[i].Role, "user") {
req.Messages[i].Content = anonymized
return
}
}
}
// sha256hex returns the lowercase hex SHA-256 hash of s.
func sha256hex(s string) string {
h := sha256.Sum256([]byte(s))
return hex.EncodeToString(h[:])
}
// responseContent returns the concatenated content of all choices in resp.
func responseContent(resp *provider.ChatResponse) string {
var sb strings.Builder
for _, c := range resp.Choices {
sb.WriteString(c.Message.Content)
}
return sb.String()
}
// errorType extracts a simple error type string for audit logging.
func errorType(err error) string {
if err == nil {
return ""
}
if apiErr, ok := err.(*apierror.APIError); ok {
return apiErr.Type
}
return "upstream_error"
}
// inferProvider maps a model name prefix to a provider name, mirroring the
// router's static prefix rules. Used for streaming where no ChatResponse is available.
func inferProvider(model string) string {
prefixes := []struct{ prefix, provider string }{
{"gpt-", "openai"},
{"o1-", "openai"},
{"o3-", "openai"},
{"claude-", "anthropic"},
{"mistral-", "mistral"},
{"mixtral-", "mistral"},
{"llama", "ollama"},
{"phi", "ollama"},
{"qwen", "ollama"},
}
for _, p := range prefixes {
if strings.HasPrefix(model, p.prefix) {
return p.provider
}
}
return "openai" // default fallback
}
// estimateTokens returns a rough token count (1 token ≈ 4 chars).
func estimateTokens(req *provider.ChatRequest) int {
total := 0
for _, m := range req.Messages {
total += len(m.Content) / 4
}
return total
}