veylant/cmd/proxy/main.go
2026-03-10 09:20:38 +01:00

481 lines
17 KiB
Go

package main
import (
"context"
"database/sql"
"errors"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/go-chi/chi/v5"
chimiddleware "github.com/go-chi/chi/v5/middleware"
"github.com/prometheus/client_golang/prometheus/promhttp"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
_ "github.com/jackc/pgx/v5/stdlib" // register pgx driver
"github.com/veylant/ia-gateway/internal/admin"
"github.com/veylant/ia-gateway/internal/auth"
"github.com/veylant/ia-gateway/internal/auditlog"
"github.com/veylant/ia-gateway/internal/circuitbreaker"
"github.com/veylant/ia-gateway/internal/compliance"
"github.com/veylant/ia-gateway/internal/config"
"github.com/veylant/ia-gateway/internal/crypto"
"github.com/veylant/ia-gateway/internal/flags"
"github.com/veylant/ia-gateway/internal/health"
"github.com/veylant/ia-gateway/internal/metrics"
"github.com/veylant/ia-gateway/internal/middleware"
"github.com/veylant/ia-gateway/internal/pii"
"github.com/veylant/ia-gateway/internal/provider"
"github.com/veylant/ia-gateway/internal/provider/anthropic"
"github.com/veylant/ia-gateway/internal/provider/azure"
"github.com/veylant/ia-gateway/internal/provider/mistral"
"github.com/veylant/ia-gateway/internal/provider/ollama"
"github.com/veylant/ia-gateway/internal/provider/openai"
"github.com/veylant/ia-gateway/internal/proxy"
"github.com/veylant/ia-gateway/internal/ratelimit"
"github.com/veylant/ia-gateway/internal/router"
"github.com/veylant/ia-gateway/internal/routing"
)
// main is the entry point of the Veylant IA proxy. All of the work happens in
// run; main only converts run's result into the process exit status. Keeping
// the single os.Exit call here lets every deferred cleanup inside run
// (ClickHouse logger Stop, PII client Close, database pool Close, logger Sync)
// execute before the process exits, because os.Exit does not run deferred
// functions.
func main() {
	os.Exit(run())
}

// run builds every component of the proxy from configuration, serves HTTP
// until SIGINT/SIGTERM is received or the listener fails, then drains
// in-flight requests and stops background work.
//
// It returns 0 after a clean shutdown and 1 when configuration cannot be
// loaded, the HTTP listener fails, or graceful shutdown does not complete
// before the deadline. Unrecoverable startup errors (database open failure,
// database ping failure outside "development", ClickHouse init failure outside
// "development") are still reported via logger.Fatal, which exits immediately.
func run() int {
	cfg, err := config.Load()
	if err != nil {
		fmt.Fprintf(os.Stderr, "failed to load config: %v\n", err)
		return 1
	}
	logger := buildLogger(cfg.Log.Level, cfg.Log.Format)
	defer logger.Sync() //nolint:errcheck

	ctx := context.Background()

	// ── Local JWT verifier (email/password auth — replaces Keycloak OIDC) ────
	jwtVerifier := auth.NewLocalJWTVerifier(cfg.Auth.JWTSecret)
	logger.Info("local JWT verifier initialised")

	// ── LLM provider adapters ─────────────────────────────────────────────────
	// OpenAI and Ollama are always registered; Anthropic, Azure and Mistral
	// only when their credentials are configured.
	adapters := map[string]provider.Adapter{}
	adapters["openai"] = openai.New(openai.Config{
		APIKey:         cfg.Providers.OpenAI.APIKey,
		BaseURL:        cfg.Providers.OpenAI.BaseURL,
		TimeoutSeconds: cfg.Providers.OpenAI.TimeoutSeconds,
		MaxConns:       cfg.Providers.OpenAI.MaxConns,
	})
	if cfg.Providers.Anthropic.APIKey != "" {
		adapters["anthropic"] = anthropic.New(anthropic.Config{
			APIKey:         cfg.Providers.Anthropic.APIKey,
			BaseURL:        cfg.Providers.Anthropic.BaseURL,
			Version:        cfg.Providers.Anthropic.Version,
			TimeoutSeconds: cfg.Providers.Anthropic.TimeoutSeconds,
			MaxConns:       cfg.Providers.Anthropic.MaxConns,
		})
		logger.Info("Anthropic adapter enabled")
	}
	if cfg.Providers.Azure.ResourceName != "" && cfg.Providers.Azure.APIKey != "" {
		adapters["azure"] = azure.New(azure.Config{
			APIKey:         cfg.Providers.Azure.APIKey,
			ResourceName:   cfg.Providers.Azure.ResourceName,
			DeploymentID:   cfg.Providers.Azure.DeploymentID,
			APIVersion:     cfg.Providers.Azure.APIVersion,
			TimeoutSeconds: cfg.Providers.Azure.TimeoutSeconds,
			MaxConns:       cfg.Providers.Azure.MaxConns,
		})
		logger.Info("Azure OpenAI adapter enabled",
			zap.String("resource", cfg.Providers.Azure.ResourceName),
			zap.String("deployment", cfg.Providers.Azure.DeploymentID),
		)
	}
	if cfg.Providers.Mistral.APIKey != "" {
		adapters["mistral"] = mistral.New(mistral.Config{
			APIKey:         cfg.Providers.Mistral.APIKey,
			BaseURL:        cfg.Providers.Mistral.BaseURL,
			TimeoutSeconds: cfg.Providers.Mistral.TimeoutSeconds,
			MaxConns:       cfg.Providers.Mistral.MaxConns,
		})
		logger.Info("Mistral adapter enabled")
	}
	adapters["ollama"] = ollama.New(ollama.Config{
		BaseURL:        cfg.Providers.Ollama.BaseURL,
		TimeoutSeconds: cfg.Providers.Ollama.TimeoutSeconds,
		MaxConns:       cfg.Providers.Ollama.MaxConns,
	})
	logger.Info("Ollama adapter enabled", zap.String("base_url", cfg.Providers.Ollama.BaseURL))

	// ── Database (PostgreSQL via pgx) ─────────────────────────────────────────
	var db *sql.DB
	if cfg.Database.URL != "" {
		var dbErr error
		db, dbErr = sql.Open("pgx", cfg.Database.URL)
		if dbErr != nil {
			logger.Fatal("failed to open database", zap.Error(dbErr))
		}
		db.SetMaxOpenConns(cfg.Database.MaxOpenConns)
		db.SetMaxIdleConns(cfg.Database.MaxIdleConns)

		// Bound the startup ping so an unreachable host cannot block boot
		// indefinitely.
		pingCtx, cancelPing := context.WithTimeout(ctx, 10*time.Second)
		pingErr := db.PingContext(pingCtx)
		cancelPing()

		switch {
		case pingErr == nil:
			defer db.Close() //nolint:errcheck
			// The URL embeds credentials, so it is intentionally not logged.
			logger.Info("database connected")
		case cfg.Server.Env == "development":
			logger.Warn("database unavailable — routing engine disabled", zap.Error(pingErr))
			// Release the pool we are about to abandon.
			_ = db.Close()
			db = nil
		default:
			logger.Fatal("database ping failed", zap.Error(pingErr))
		}
	}

	// ── Routing engine ────────────────────────────────────────────────────────
	var routingEngine *routing.Engine
	if db != nil {
		ttl := time.Duration(cfg.Routing.CacheTTLSeconds) * time.Second
		if ttl <= 0 {
			ttl = 30 * time.Second
		}
		pgStore := routing.NewPgStore(db, logger)
		routingEngine = routing.New(pgStore, ttl, logger)
		routingEngine.Start()
		logger.Info("routing engine started", zap.Duration("cache_ttl", ttl))
	}

	// ── Circuit breaker (E2-09) ───────────────────────────────────────────────
	const (
		cbFailureThreshold = 5
		cbOpenTTL          = 60 * time.Second
	)
	cb := circuitbreaker.New(cbFailureThreshold, cbOpenTTL)
	logger.Info("circuit breaker initialised", zap.Int("threshold", cbFailureThreshold), zap.Duration("open_ttl", cbOpenTTL))

	// ── Provider router (RBAC + model dispatch + optional engine) ─────────────
	providerRouter := router.NewWithEngineAndBreaker(adapters, &cfg.RBAC, routingEngine, cb, logger)
	logger.Info("provider router initialised",
		zap.Int("adapter_count", len(adapters)),
		zap.Strings("user_allowed_models", cfg.RBAC.UserAllowedModels),
		zap.Bool("routing_engine", routingEngine != nil),
	)

	// ── PII client (optional) ─────────────────────────────────────────────────
	var piiClient *pii.Client
	if cfg.PII.Enabled {
		pc, piiErr := pii.New(pii.Config{
			Address:  cfg.PII.ServiceAddr,
			Timeout:  time.Duration(cfg.PII.TimeoutMs) * time.Millisecond,
			FailOpen: cfg.PII.FailOpen,
		}, logger)
		if piiErr != nil {
			logger.Warn("PII client init failed — PII disabled", zap.Error(piiErr))
		} else {
			piiClient = pc
			defer pc.Close() //nolint:errcheck
			logger.Info("PII client connected", zap.String("addr", cfg.PII.ServiceAddr))
		}
	}

	// ── AES-256-GCM encryptor (optional) ─────────────────────────────────────
	var encryptor *crypto.Encryptor
	if cfg.Crypto.AESKeyBase64 != "" {
		enc, encErr := crypto.NewEncryptor(cfg.Crypto.AESKeyBase64)
		if encErr != nil {
			logger.Warn("crypto encryptor init failed — prompt encryption disabled", zap.Error(encErr))
		} else {
			encryptor = enc
			logger.Info("AES-256-GCM encryptor enabled")
		}
	} else {
		logger.Warn("VEYLANT_CRYPTO_AES_KEY_BASE64 not set — prompt encryption disabled")
	}

	// ── Load provider configs from DB (supplements static config-based adapters) ─
	if db != nil {
		loadProvidersFromDB(ctx, db, providerRouter, encryptor, logger)
	}

	// ── ClickHouse audit logger (optional) ────────────────────────────────────
	var auditLogger auditlog.Logger
	if cfg.ClickHouse.DSN != "" {
		chLogger, chErr := auditlog.NewClickHouseLogger(
			cfg.ClickHouse.DSN,
			cfg.ClickHouse.MaxConns,
			cfg.ClickHouse.DialTimeoutSec,
			logger,
		)
		if chErr != nil {
			if cfg.Server.Env == "development" {
				logger.Warn("ClickHouse unavailable — audit logging disabled", zap.Error(chErr))
			} else {
				logger.Fatal("ClickHouse init failed", zap.Error(chErr))
			}
		} else {
			// Apply DDL idempotently.
			ddlPath := "migrations/clickhouse/000001_audit_logs.sql"
			if ddlErr := chLogger.ApplyDDL(ddlPath); ddlErr != nil {
				logger.Warn("ClickHouse DDL apply failed — audit logging disabled", zap.Error(ddlErr))
			} else {
				chLogger.Start()
				defer chLogger.Stop()
				auditLogger = chLogger
				// The DSN embeds credentials, so it is intentionally not logged.
				logger.Info("ClickHouse audit logger started")
			}
		}
	} else {
		logger.Warn("clickhouse.dsn not set — audit logging disabled")
	}

	// ── Feature flag store (E4-12 zero-retention + future flags + E11-07) ──────
	var flagStore flags.FlagStore
	if db != nil {
		flagStore = flags.NewPgFlagStore(db, logger)
		logger.Info("feature flag store: PostgreSQL")
	} else {
		flagStore = flags.NewMemFlagStore()
		logger.Warn("feature flag store: in-memory (no database)")
	}
	// Wire flag store into the provider router so it can check routing_enabled (E11-07).
	providerRouter.WithFlagStore(flagStore)

	// ── Proxy handler ─────────────────────────────────────────────────────────
	proxyHandler := proxy.NewWithAudit(providerRouter, logger, piiClient, auditLogger, encryptor).
		WithFlagStore(flagStore)

	// ── Rate limiter (E10-09) ─────────────────────────────────────────────────
	rateLimiter := ratelimit.New(ratelimit.RateLimitConfig{
		RequestsPerMin: cfg.RateLimit.DefaultTenantRPM,
		BurstSize:      cfg.RateLimit.DefaultTenantBurst,
		UserRPM:        cfg.RateLimit.DefaultUserRPM,
		UserBurst:      cfg.RateLimit.DefaultUserBurst,
		IsEnabled:      true,
	}, logger)
	// Load per-tenant overrides from DB (best-effort; missing DB is graceful).
	if db != nil {
		rlStore := ratelimit.NewStore(db, logger)
		overrides, listErr := rlStore.List(ctx)
		if listErr != nil {
			logger.Warn("failed to load rate limit overrides", zap.Error(listErr))
		} else {
			// Named "override" so it does not shadow the process-wide cfg.
			for _, override := range overrides {
				rateLimiter.SetConfig(override)
			}
			logger.Info("rate limit overrides loaded", zap.Int("count", len(overrides)))
		}
	}
	logger.Info("rate limiter initialised",
		zap.Int("default_tenant_rpm", cfg.RateLimit.DefaultTenantRPM),
		zap.Int("default_user_rpm", cfg.RateLimit.DefaultUserRPM),
	)

	// ── HTTP router ───────────────────────────────────────────────────────────
	r := chi.NewRouter()
	r.Use(middleware.SecurityHeaders(cfg.Server.Env))
	r.Use(middleware.RequestID)
	r.Use(chimiddleware.RealIP)
	r.Use(chimiddleware.Recoverer)
	if cfg.Metrics.Enabled {
		r.Use(metrics.Middleware("openai"))
	}
	r.Get("/", health.LandingHandler)
	r.Get("/healthz", health.Handler)
	// OpenAPI documentation (E11-02).
	r.Get("/docs", health.DocsHTMLHandler)
	r.Get("/docs/openapi.yaml", health.DocsYAMLHandler)
	// Public PII playground — no JWT required (E8-15).
	r.Get("/playground", health.PlaygroundHandler)
	r.Post("/playground/analyze", health.PlaygroundAnalyzeHandler(piiClient, logger))
	if cfg.Metrics.Enabled {
		r.Get(cfg.Metrics.Path, promhttp.Handler().ServeHTTP)
	}

	loginHandler := auth.NewLoginHandler(db, cfg.Auth.JWTSecret, cfg.Auth.JWTTTLHours, logger)
	r.Route("/v1", func(r chi.Router) {
		r.Use(middleware.CORS(cfg.Server.AllowedOrigins))
		// Public — CORS applied, no auth required.
		r.Post("/auth/login", loginHandler.ServeHTTP)
		// Protected — JWT auth + tenant rate limit.
		r.Group(func(r chi.Router) {
			r.Use(middleware.Auth(jwtVerifier))
			r.Use(middleware.RateLimit(rateLimiter))
			r.Post("/chat/completions", proxyHandler.ServeHTTP)
			// PII analyze endpoint for Playground (E8-11, Sprint 8).
			piiAnalyzeHandler := pii.NewAnalyzeHandler(piiClient, logger)
			r.Post("/pii/analyze", piiAnalyzeHandler.ServeHTTP)
			// Admin API — routing policies + audit logs (Sprint 5 + Sprint 6)
			// + user management + provider status (Sprint 8).
			if routingEngine != nil {
				var adminHandler *admin.Handler
				if auditLogger != nil {
					adminHandler = admin.NewWithAudit(
						routing.NewPgStore(db, logger),
						routingEngine.Cache(),
						auditLogger,
						logger,
					)
				} else {
					adminHandler = admin.New(
						routing.NewPgStore(db, logger),
						routingEngine.Cache(),
						logger,
					)
				}
				// Wire db, router, rate limiter, feature flags, and encryptor.
				adminHandler.WithDB(db).WithRouter(providerRouter).WithRateLimiter(rateLimiter).WithFlagStore(flagStore).WithEncryptor(encryptor)
				r.Route("/admin", adminHandler.Routes)
			}
			// Compliance module — GDPR Art. 30 registry + AI Act classification + PDF reports (Sprint 9).
			if db != nil {
				compStore := compliance.NewPgStore(db, logger)
				compHandler := compliance.New(compStore, logger).
					WithAudit(auditLogger).
					WithDB(db).
					WithTenantName(cfg.Server.TenantName)
				r.Route("/admin/compliance", compHandler.Routes)
				logger.Info("compliance module started")
			}
		})
	})

	// ── HTTP server ───────────────────────────────────────────────────────────
	addr := fmt.Sprintf(":%d", cfg.Server.Port)
	srv := &http.Server{
		Addr:         addr,
		Handler:      r,
		ReadTimeout:  30 * time.Second,
		WriteTimeout: 30 * time.Second,
		IdleTimeout:  120 * time.Second,
	}

	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGTERM, syscall.SIGINT)

	// Listener failures are reported on serverErr rather than via
	// logger.Fatal, so graceful shutdown and deferred cleanup still run.
	// Buffered so the goroutine never blocks if nobody is receiving.
	serverErr := make(chan error, 1)
	go func() {
		logger.Info("Veylant IA proxy started",
			zap.String("addr", addr),
			zap.String("env", cfg.Server.Env),
			zap.Bool("metrics", cfg.Metrics.Enabled),
			zap.String("auth", "local-jwt"),
			zap.Bool("audit_logging", auditLogger != nil),
			zap.Bool("encryption", encryptor != nil),
		)
		if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			serverErr <- err
		}
	}()

	exitCode := 0
	select {
	case sig := <-quit:
		logger.Info("shutdown signal received, draining connections...", zap.String("signal", sig.String()))
	case srvErr := <-serverErr:
		logger.Error("server error", zap.Error(srvErr))
		exitCode = 1
	}
	// Restore default signal handling so a second SIGINT/SIGTERM terminates
	// the process immediately if draining hangs.
	signal.Stop(quit)

	timeout := time.Duration(cfg.Server.ShutdownTimeout) * time.Second
	if timeout <= 0 {
		// A zero deadline would make Shutdown abandon every open connection
		// immediately instead of draining it.
		timeout = 30 * time.Second
	}
	shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	if err := srv.Shutdown(shutdownCtx); err != nil {
		logger.Error("graceful shutdown failed", zap.Error(err))
		exitCode = 1
	}

	// Stop the routing engine (started with Start above) only after HTTP
	// draining has finished, so requests still being served are not handled
	// while it is being torn down.
	if routingEngine != nil {
		routingEngine.Stop()
	}

	if exitCode == 0 {
		logger.Info("server stopped cleanly")
	}
	return exitCode
}
// loadProvidersFromDB reads active provider_configs from the database and
// registers/updates adapters on the router. Called once at startup so DB-saved
// configs take precedence over (or extend) the static config.yaml ones.
func loadProvidersFromDB(ctx context.Context, db *sql.DB, providerRouter *router.Router, enc *crypto.Encryptor, logger *zap.Logger) {
rows, err := db.QueryContext(ctx,
`SELECT provider, api_key_enc, base_url, resource_name, deployment_id,
api_version, timeout_sec, max_conns, model_patterns
FROM provider_configs WHERE is_active = TRUE`)
if err != nil {
logger.Warn("failed to load provider configs from DB", zap.Error(err))
return
}
defer rows.Close()
count := 0
for rows.Next() {
var (
prov, apiKeyEnc, baseURL, resourceName, deploymentID, apiVersion string
timeoutSec, maxConns int
patternsRaw string
)
if err := rows.Scan(&prov, &apiKeyEnc, &baseURL, &resourceName,
&deploymentID, &apiVersion, &timeoutSec, &maxConns, &patternsRaw); err != nil {
logger.Warn("failed to scan provider config row", zap.Error(err))
continue
}
// Decrypt API key
apiKey := apiKeyEnc
if enc != nil && apiKeyEnc != "" {
if dec, decErr := enc.Decrypt(apiKeyEnc); decErr == nil {
apiKey = dec
}
}
cfg := &admin.ProviderConfig{
Provider: prov,
BaseURL: baseURL,
ResourceName: resourceName,
DeploymentID: deploymentID,
APIVersion: apiVersion,
TimeoutSec: timeoutSec,
MaxConns: maxConns,
}
adapter := admin.BuildProviderAdapter(cfg, apiKey)
if adapter != nil {
providerRouter.UpdateAdapter(prov, adapter)
count++
}
}
if count > 0 {
logger.Info("provider configs loaded from DB", zap.Int("count", count))
}
}
// buildLogger constructs the process-wide zap logger.
//
// level is parsed with zapcore.Level.UnmarshalText; any value it rejects
// falls back to the info level. format selects the "console" encoding only
// when it is exactly "console" and "json" otherwise. The encoder starts from
// zap's production encoder config, renames the time key to "timestamp" and
// encodes times as ISO8601. Output goes to stdout and zap's internal errors
// to stderr. It panics if the resulting configuration cannot be built.
func buildLogger(level, format string) *zap.Logger {
	minLevel := zap.InfoLevel
	if parseErr := minLevel.UnmarshalText([]byte(level)); parseErr != nil {
		minLevel = zap.InfoLevel
	}

	var encoding string
	switch format {
	case "console":
		encoding = "console"
	default:
		encoding = "json"
	}

	encCfg := zap.NewProductionEncoderConfig()
	encCfg.TimeKey, encCfg.EncodeTime = "timestamp", zapcore.ISO8601TimeEncoder

	zc := zap.Config{
		Level:            zap.NewAtomicLevelAt(minLevel),
		Development:      false,
		Encoding:         encoding,
		EncoderConfig:    encCfg,
		OutputPaths:      []string{"stdout"},
		ErrorOutputPaths: []string{"stderr"},
	}
	built, buildErr := zc.Build()
	if buildErr != nil {
		panic(fmt.Sprintf("failed to build logger: %v", buildErr))
	}
	return built
}