// Package proxy implements the HTTP handler for the /v1/chat/completions endpoint.
|
|
package proxy
|
|
|
|
import (
	"context"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"errors"
	"net/http"
	"strings"
	"time"

	"go.uber.org/zap"

	"github.com/veylant/ia-gateway/internal/apierror"
	"github.com/veylant/ia-gateway/internal/auditlog"
	"github.com/veylant/ia-gateway/internal/billing"
	"github.com/veylant/ia-gateway/internal/crypto"
	"github.com/veylant/ia-gateway/internal/flags"
	"github.com/veylant/ia-gateway/internal/middleware"
	"github.com/veylant/ia-gateway/internal/pii"
	"github.com/veylant/ia-gateway/internal/provider"
	"github.com/veylant/ia-gateway/internal/routing"
)
|
|
|
|
// Handler handles POST /v1/chat/completions.
//
// It dispatches to the underlying provider.Adapter for both non-streaming and
// streaming (SSE) requests, optionally running PII anonymization before the
// upstream call and de-pseudonymizing non-streaming responses.
//
// All optional collaborators are nil-safe: a nil field simply disables the
// corresponding feature (see the per-field comments below).
type Handler struct {
	adapter     provider.Adapter
	piiClient   *pii.Client       // nil means PII disabled
	auditLogger auditlog.Logger   // nil means audit logging disabled
	encryptor   *crypto.Encryptor // nil means encryption disabled
	flagStore   flags.FlagStore   // nil means feature flags disabled
	logger      *zap.Logger
}
|
|
|
|
// New creates a Handler backed by adapter.
|
|
// Pass a non-nil piiClient to enable PII anonymization.
|
|
func New(adapter provider.Adapter, logger *zap.Logger, piiClient *pii.Client) *Handler {
|
|
return &Handler{adapter: adapter, piiClient: piiClient, logger: logger}
|
|
}
|
|
|
|
// NewWithAudit creates a Handler with audit logging and prompt encryption enabled.
|
|
// Either auditLogger or encryptor may be nil to disable those features independently.
|
|
func NewWithAudit(
|
|
adapter provider.Adapter,
|
|
logger *zap.Logger,
|
|
piiClient *pii.Client,
|
|
al auditlog.Logger,
|
|
enc *crypto.Encryptor,
|
|
) *Handler {
|
|
return &Handler{
|
|
adapter: adapter,
|
|
piiClient: piiClient,
|
|
auditLogger: al,
|
|
encryptor: enc,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// WithFlagStore attaches a feature flag store to the handler (used for zero-retention, etc.).
|
|
func (h *Handler) WithFlagStore(fs flags.FlagStore) *Handler {
|
|
h.flagStore = fs
|
|
return h
|
|
}
|
|
|
|
// requestMeta groups per-request metadata to avoid long function signatures.
type requestMeta struct {
	requestID      string              // correlation ID from the request-ID middleware
	tenantID       string              // from JWT claims; empty when no claims are present
	userID         string              // from JWT claims
	userRole       string              // first entry of claims.Roles, if any
	department     string              // from JWT claims
	sensitivity    routing.Sensitivity // set only when PII entities were detected
	piiEntityCount int                 // number of PII entities found in the prompt
	startTime      time.Time           // request arrival time; basis for latency metrics
}
|
|
|
|
// ServeHTTP implements http.Handler.
//
// Pipeline: decode the JSON body, validate via the adapter, gather per-request
// metadata from auth claims, optionally anonymize PII in the last user
// message, then dispatch to handleStream (SSE) or handleSend (JSON) depending
// on req.Stream.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	start := time.Now()
	requestID := middleware.RequestIDFromContext(r.Context())

	var req provider.ChatRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		apierror.WriteError(w, apierror.NewBadRequestError("invalid JSON body: "+err.Error()))
		return
	}

	// Adapter-specific validation; preserve structured API errors when the
	// adapter returns one, otherwise wrap as a generic 400.
	if err := h.adapter.Validate(&req); err != nil {
		if apiErr, ok := err.(*apierror.APIError); ok {
			apierror.WriteError(w, apiErr)
		} else {
			apierror.WriteError(w, apierror.NewBadRequestError(err.Error()))
		}
		return
	}

	// Collect request metadata for logging / audit.
	meta := requestMeta{requestID: requestID, startTime: start}
	logFields := []zap.Field{
		zap.String("request_id", requestID),
		zap.String("model", req.Model),
		zap.Bool("stream", req.Stream),
	}
	if claims, ok := middleware.ClaimsFromContext(r.Context()); ok {
		meta.userID = claims.UserID
		meta.tenantID = claims.TenantID
		meta.department = claims.Department
		// Only the first role is recorded for audit purposes.
		if len(claims.Roles) > 0 {
			meta.userRole = claims.Roles[0]
		}
		logFields = append(logFields,
			zap.String("user_id", meta.userID),
			zap.String("tenant_id", meta.tenantID),
		)
	}

	// Capture the original (pre-PII) prompt for hashing.
	originalPrompt := extractLastUserMessage(&req)

	// ── PII anonymization (request side) ────────────────────────────────────
	var entityMap map[string]string
	var anonymizedPrompt string

	// Check pii_enabled flag (E11-07): defaults to true when not set.
	piiEnabled := true
	if h.flagStore != nil && meta.tenantID != "" {
		if enabled, err := h.flagStore.IsEnabled(r.Context(), meta.tenantID, "pii_enabled"); err == nil {
			// IsEnabled returns false both when the flag is disabled AND when it
			// doesn't exist (no flag set). We treat absence as "enabled" to avoid
			// breaking existing tenants who haven't set this flag. The migration
			// seeds a global default of true, so after applying 000009 this is
			// always correct. Before that migration, we fail-open (pii enabled).
			piiEnabled = enabled
		}
	}

	if h.piiClient != nil && piiEnabled {
		promptText := originalPrompt
		if promptText != "" {
			// Check zero-retention flag: if enabled, the PII service will not persist
			// pseudonymization mappings in Redis (E4-12).
			zeroRetention := false
			if h.flagStore != nil && meta.tenantID != "" {
				// Best-effort: a flag-store error leaves zeroRetention false.
				zeroRetention, _ = h.flagStore.IsEnabled(r.Context(), meta.tenantID, "zero_retention")
			}
			result, err := h.piiClient.Detect(r.Context(), promptText, meta.tenantID, requestID, true, zeroRetention)
			if err != nil {
				// PII detection is mandatory once enabled: fail the request
				// rather than forwarding un-anonymized text upstream.
				apierror.WriteError(w, apierror.NewUpstreamError("PII service error: "+err.Error()))
				return
			}

			meta.piiEntityCount = len(result.Entities)
			if len(result.Entities) > 0 {
				// Attach the sensitivity score to the request context so the
				// downstream router can pick an appropriate provider.
				score := routing.ScoreFromEntities(result.Entities)
				meta.sensitivity = score
				r = r.WithContext(routing.WithSensitivity(r.Context(), score))
			}

			// Only rewrite the message when the PII service actually changed it.
			if result.AnonymizedText != promptText {
				replaceLastUserMessage(&req, result.AnonymizedText)
				anonymizedPrompt = result.AnonymizedText
				entityMap = BuildEntityMap(result.Entities)
				h.logger.Debug("pii anonymized prompt",
					zap.Int("entities", len(result.Entities)),
					zap.Int64("pii_ms", result.ProcessingTimeMs),
				)
			}
		}
	}
	// No anonymization happened (disabled, empty prompt, or no change):
	// audit the original prompt text instead.
	if anonymizedPrompt == "" {
		anonymizedPrompt = originalPrompt
	}

	if req.Stream {
		h.handleStream(w, r, &req, meta, logFields, anonymizedPrompt, originalPrompt)
	} else {
		h.handleSend(w, r, &req, meta, logFields, entityMap, anonymizedPrompt, originalPrompt)
	}
}
|
|
|
|
func (h *Handler) handleSend(
|
|
w http.ResponseWriter,
|
|
r *http.Request,
|
|
req *provider.ChatRequest,
|
|
meta requestMeta,
|
|
logFields []zap.Field,
|
|
entityMap map[string]string,
|
|
anonymizedPrompt, originalPrompt string,
|
|
) {
|
|
resp, err := h.adapter.Send(r.Context(), req)
|
|
latencyMs := int(time.Since(meta.startTime).Milliseconds())
|
|
|
|
if err != nil {
|
|
h.fireAuditEntry(meta, req, nil, originalPrompt, anonymizedPrompt, latencyMs, err)
|
|
if apiErr, ok := err.(*apierror.APIError); ok {
|
|
apierror.WriteError(w, apiErr)
|
|
} else {
|
|
apierror.WriteError(w, apierror.NewUpstreamError(err.Error()))
|
|
}
|
|
h.logger.Error("proxy send error", append(logFields, zap.Error(err))...)
|
|
return
|
|
}
|
|
|
|
// De-pseudonymize the LLM response choices.
|
|
if len(entityMap) > 0 {
|
|
for i := range resp.Choices {
|
|
resp.Choices[i].Message.Content = Depseudonymize(resp.Choices[i].Message.Content, entityMap)
|
|
}
|
|
}
|
|
|
|
h.fireAuditEntry(meta, req, resp, originalPrompt, anonymizedPrompt, latencyMs, nil)
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(http.StatusOK)
|
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
|
h.logger.Error("encoding response", append(logFields, zap.Error(err))...)
|
|
}
|
|
|
|
h.logger.Info("proxy send ok",
|
|
append(logFields,
|
|
zap.Duration("latency", time.Since(meta.startTime)),
|
|
zap.Int("total_tokens", resp.Usage.TotalTokens),
|
|
)...,
|
|
)
|
|
}
|
|
|
|
func (h *Handler) handleStream(
|
|
w http.ResponseWriter,
|
|
r *http.Request,
|
|
req *provider.ChatRequest,
|
|
meta requestMeta,
|
|
logFields []zap.Field,
|
|
anonymizedPrompt, originalPrompt string,
|
|
) {
|
|
// Set SSE headers before the adapter starts writing.
|
|
w.Header().Set("Content-Type", "text/event-stream")
|
|
w.Header().Set("Cache-Control", "no-cache")
|
|
w.Header().Set("Connection", "keep-alive")
|
|
w.Header().Set("X-Accel-Buffering", "no") // disable Nginx proxy buffering
|
|
w.WriteHeader(http.StatusOK)
|
|
|
|
err := h.adapter.Stream(r.Context(), req, w)
|
|
latencyMs := int(time.Since(meta.startTime).Milliseconds())
|
|
|
|
// For streaming: no ChatResponse is available; fire audit with nil resp.
|
|
h.fireAuditEntry(meta, req, nil, originalPrompt, anonymizedPrompt, latencyMs, err)
|
|
|
|
if err != nil {
|
|
// Headers are already sent — we cannot change the status code.
|
|
h.logger.Error("proxy stream error", append(logFields, zap.Error(err))...)
|
|
return
|
|
}
|
|
|
|
h.logger.Info("proxy stream ok",
|
|
append(logFields, zap.Duration("latency", time.Since(meta.startTime)))...,
|
|
)
|
|
}
|
|
|
|
// fireAuditEntry builds and enqueues an AuditEntry. It is a no-op when auditLogger is nil.
//
// The entry records hashed (never plaintext) prompt/response content; the
// anonymized prompt is stored only when an encryptor is configured, and then
// only in encrypted form.
func (h *Handler) fireAuditEntry(
	meta requestMeta,
	req *provider.ChatRequest,
	resp *provider.ChatResponse,
	originalPrompt, anonymizedPrompt string,
	latencyMs int,
	reqErr error,
) {
	if h.auditLogger == nil {
		return
	}

	// Base entry populated from request-side data; response-side fields are
	// filled below when resp is non-nil.
	entry := auditlog.AuditEntry{
		RequestID:        meta.requestID,
		TenantID:         meta.tenantID,
		UserID:           meta.userID,
		Timestamp:        meta.startTime,
		ModelRequested:   req.Model,
		ModelUsed:        req.Model, // overwritten by resp.Model when available
		Department:       meta.department,
		UserRole:         meta.userRole,
		PromptHash:       sha256hex(originalPrompt),
		SensitivityLevel: meta.sensitivity.String(),
		LatencyMs:        latencyMs,
		PIIEntityCount:   meta.piiEntityCount,
		Stream:           req.Stream,
		Status:           "ok",
	}

	if reqErr != nil {
		entry.Status = "error"
		entry.ErrorType = errorType(reqErr)
	}

	if resp != nil {
		// Non-streaming success: exact provider, model, and token usage are known.
		entry.ModelUsed = resp.Model
		entry.Provider = resp.Provider
		entry.ResponseHash = sha256hex(responseContent(resp))
		entry.TokenInput = resp.Usage.PromptTokens
		entry.TokenOutput = resp.Usage.CompletionTokens
		entry.TokenTotal = resp.Usage.TotalTokens
		entry.CostUSD = billing.CostUSD(resp.Provider, resp.Model, resp.Usage.TotalTokens)
	} else {
		// For streaming or errors, infer provider from model prefix.
		entry.Provider = inferProvider(req.Model)
		entry.TokenTotal = estimateTokens(req)
		entry.CostUSD = billing.CostUSD(entry.Provider, req.Model, entry.TokenTotal)
	}

	// billing_enabled flag (E11-07): if explicitly disabled for this tenant,
	// zero out the cost so the tenant is not charged in the audit log.
	// NOTE(review): uses context.Background() rather than the request context —
	// presumably so the lookup still succeeds after the client disconnects or
	// the request context is canceled; confirm this is intentional.
	if h.flagStore != nil && meta.tenantID != "" {
		if billingEnabled, err := h.flagStore.IsEnabled(context.Background(), meta.tenantID, "billing_enabled"); err == nil && !billingEnabled {
			entry.CostUSD = 0
		}
	}

	// Encryption failure degrades gracefully: the entry is still logged,
	// just without the anonymized prompt.
	if h.encryptor != nil && anonymizedPrompt != "" {
		encrypted, err := h.encryptor.Encrypt(anonymizedPrompt)
		if err != nil {
			h.logger.Warn("audit log: failed to encrypt prompt", zap.Error(err))
		} else {
			entry.PromptAnonymized = encrypted
		}
	}

	h.auditLogger.Log(entry)
}
|
|
|
|
// ── helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
// extractLastUserMessage returns the content of the last "user" message in req.
|
|
func extractLastUserMessage(req *provider.ChatRequest) string {
|
|
for i := len(req.Messages) - 1; i >= 0; i-- {
|
|
if strings.EqualFold(req.Messages[i].Role, "user") {
|
|
return req.Messages[i].Content
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
// replaceLastUserMessage replaces the content of the last user message with anonymized.
|
|
func replaceLastUserMessage(req *provider.ChatRequest, anonymized string) {
|
|
for i := len(req.Messages) - 1; i >= 0; i-- {
|
|
if strings.EqualFold(req.Messages[i].Role, "user") {
|
|
req.Messages[i].Content = anonymized
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// sha256hex returns the lowercase hex-encoded SHA-256 digest of s.
func sha256hex(s string) string {
	digest := sha256.Sum256([]byte(s))
	return hex.EncodeToString(digest[:])
}
|
|
|
|
// responseContent returns the concatenated content of all choices in resp.
|
|
func responseContent(resp *provider.ChatResponse) string {
|
|
var sb strings.Builder
|
|
for _, c := range resp.Choices {
|
|
sb.WriteString(c.Message.Content)
|
|
}
|
|
return sb.String()
|
|
}
|
|
|
|
// errorType extracts a simple error type string for audit logging.
|
|
func errorType(err error) string {
|
|
if err == nil {
|
|
return ""
|
|
}
|
|
if apiErr, ok := err.(*apierror.APIError); ok {
|
|
return apiErr.Type
|
|
}
|
|
return "upstream_error"
|
|
}
|
|
|
|
// inferProvider maps a model name prefix to a provider name, mirroring the
// router's static prefix rules. Used for streaming where no ChatResponse is
// available. Unrecognized model names fall back to "openai".
func inferProvider(model string) string {
	switch {
	case strings.HasPrefix(model, "gpt-"),
		strings.HasPrefix(model, "o1-"),
		strings.HasPrefix(model, "o3-"):
		return "openai"
	case strings.HasPrefix(model, "claude-"):
		return "anthropic"
	case strings.HasPrefix(model, "mistral-"),
		strings.HasPrefix(model, "mixtral-"):
		return "mistral"
	case strings.HasPrefix(model, "llama"),
		strings.HasPrefix(model, "phi"),
		strings.HasPrefix(model, "qwen"):
		return "ollama"
	default:
		return "openai" // default fallback
	}
}
|
|
|
|
// estimateTokens returns a rough token count (1 token ≈ 4 chars).
|
|
func estimateTokens(req *provider.ChatRequest) int {
|
|
total := 0
|
|
for _, m := range req.Messages {
|
|
total += len(m.Content) / 4
|
|
}
|
|
return total
|
|
}
|