package main import ( "context" "database/sql" "errors" "fmt" "net/http" "os" "os/signal" "syscall" "time" "github.com/go-chi/chi/v5" chimiddleware "github.com/go-chi/chi/v5/middleware" "github.com/prometheus/client_golang/prometheus/promhttp" "go.uber.org/zap" "go.uber.org/zap/zapcore" _ "github.com/jackc/pgx/v5/stdlib" // register pgx driver "github.com/veylant/ia-gateway/internal/admin" "github.com/veylant/ia-gateway/internal/auth" "github.com/veylant/ia-gateway/internal/auditlog" "github.com/veylant/ia-gateway/internal/circuitbreaker" "github.com/veylant/ia-gateway/internal/compliance" "github.com/veylant/ia-gateway/internal/config" "github.com/veylant/ia-gateway/internal/crypto" "github.com/veylant/ia-gateway/internal/flags" "github.com/veylant/ia-gateway/internal/health" "github.com/veylant/ia-gateway/internal/metrics" "github.com/veylant/ia-gateway/internal/middleware" "github.com/veylant/ia-gateway/internal/pii" "github.com/veylant/ia-gateway/internal/provider" "github.com/veylant/ia-gateway/internal/provider/anthropic" "github.com/veylant/ia-gateway/internal/provider/azure" "github.com/veylant/ia-gateway/internal/provider/mistral" "github.com/veylant/ia-gateway/internal/provider/ollama" "github.com/veylant/ia-gateway/internal/provider/openai" "github.com/veylant/ia-gateway/internal/proxy" "github.com/veylant/ia-gateway/internal/ratelimit" "github.com/veylant/ia-gateway/internal/router" "github.com/veylant/ia-gateway/internal/routing" ) func main() { cfg, err := config.Load() if err != nil { fmt.Fprintf(os.Stderr, "failed to load config: %v\n", err) os.Exit(1) } logger := buildLogger(cfg.Log.Level, cfg.Log.Format) defer logger.Sync() //nolint:errcheck // ── Local JWT verifier (email/password auth — replaces Keycloak OIDC) ──── ctx := context.Background() jwtVerifier := auth.NewLocalJWTVerifier(cfg.Auth.JWTSecret) logger.Info("local JWT verifier initialised") // ── LLM provider adapters ───────────────────────────────────────────────── adapters := map[string]provider.Adapter{} adapters["openai"] = openai.New(openai.Config{ APIKey: cfg.Providers.OpenAI.APIKey, BaseURL: cfg.Providers.OpenAI.BaseURL, TimeoutSeconds: cfg.Providers.OpenAI.TimeoutSeconds, MaxConns: cfg.Providers.OpenAI.MaxConns, }) if cfg.Providers.Anthropic.APIKey != "" { adapters["anthropic"] = anthropic.New(anthropic.Config{ APIKey: cfg.Providers.Anthropic.APIKey, BaseURL: cfg.Providers.Anthropic.BaseURL, Version: cfg.Providers.Anthropic.Version, TimeoutSeconds: cfg.Providers.Anthropic.TimeoutSeconds, MaxConns: cfg.Providers.Anthropic.MaxConns, }) logger.Info("Anthropic adapter enabled") } if cfg.Providers.Azure.ResourceName != "" && cfg.Providers.Azure.APIKey != "" { adapters["azure"] = azure.New(azure.Config{ APIKey: cfg.Providers.Azure.APIKey, ResourceName: cfg.Providers.Azure.ResourceName, DeploymentID: cfg.Providers.Azure.DeploymentID, APIVersion: cfg.Providers.Azure.APIVersion, TimeoutSeconds: cfg.Providers.Azure.TimeoutSeconds, MaxConns: cfg.Providers.Azure.MaxConns, }) logger.Info("Azure OpenAI adapter enabled", zap.String("resource", cfg.Providers.Azure.ResourceName), zap.String("deployment", cfg.Providers.Azure.DeploymentID), ) } if cfg.Providers.Mistral.APIKey != "" { adapters["mistral"] = mistral.New(mistral.Config{ APIKey: cfg.Providers.Mistral.APIKey, BaseURL: cfg.Providers.Mistral.BaseURL, TimeoutSeconds: cfg.Providers.Mistral.TimeoutSeconds, MaxConns: cfg.Providers.Mistral.MaxConns, }) logger.Info("Mistral adapter enabled") } adapters["ollama"] = ollama.New(ollama.Config{ BaseURL: cfg.Providers.Ollama.BaseURL, TimeoutSeconds: cfg.Providers.Ollama.TimeoutSeconds, MaxConns: cfg.Providers.Ollama.MaxConns, }) logger.Info("Ollama adapter enabled", zap.String("base_url", cfg.Providers.Ollama.BaseURL)) // ── Database (PostgreSQL via pgx) ───────────────────────────────────────── var db *sql.DB if cfg.Database.URL != "" { var dbErr error db, dbErr = sql.Open("pgx", cfg.Database.URL) if dbErr != nil { logger.Fatal("failed to open database", zap.Error(dbErr)) } db.SetMaxOpenConns(cfg.Database.MaxOpenConns) db.SetMaxIdleConns(cfg.Database.MaxIdleConns) if pingErr := db.PingContext(ctx); pingErr != nil { if cfg.Server.Env == "development" { logger.Warn("database unavailable — routing engine disabled", zap.Error(pingErr)) db = nil } else { logger.Fatal("database ping failed", zap.Error(pingErr)) } } else { logger.Info("database connected", zap.String("url", cfg.Database.URL)) } } // ── Routing engine ──────────────────────────────────────────────────────── var routingEngine *routing.Engine if db != nil { ttl := time.Duration(cfg.Routing.CacheTTLSeconds) * time.Second if ttl <= 0 { ttl = 30 * time.Second } pgStore := routing.NewPgStore(db, logger) routingEngine = routing.New(pgStore, ttl, logger) routingEngine.Start() logger.Info("routing engine started", zap.Duration("cache_ttl", ttl)) } // ── Circuit breaker (E2-09) ─────────────────────────────────────────────── cb := circuitbreaker.New(5, 60*time.Second) logger.Info("circuit breaker initialised", zap.Int("threshold", 5), zap.Duration("open_ttl", 60*time.Second)) // ── Provider router (RBAC + model dispatch + optional engine) ───────────── providerRouter := router.NewWithEngineAndBreaker(adapters, &cfg.RBAC, routingEngine, cb, logger) logger.Info("provider router initialised", zap.Int("adapter_count", len(adapters)), zap.Strings("user_allowed_models", cfg.RBAC.UserAllowedModels), zap.Bool("routing_engine", routingEngine != nil), ) // ── PII client (optional) ───────────────────────────────────────────────── var piiClient *pii.Client if cfg.PII.Enabled { pc, piiErr := pii.New(pii.Config{ Address: cfg.PII.ServiceAddr, Timeout: time.Duration(cfg.PII.TimeoutMs) * time.Millisecond, FailOpen: cfg.PII.FailOpen, }, logger) if piiErr != nil { logger.Warn("PII client init failed — PII disabled", zap.Error(piiErr)) } else { piiClient = pc defer pc.Close() //nolint:errcheck logger.Info("PII client connected", zap.String("addr", cfg.PII.ServiceAddr)) } } // ── AES-256-GCM encryptor (optional) ───────────────────────────────────── var encryptor *crypto.Encryptor if cfg.Crypto.AESKeyBase64 != "" { enc, encErr := crypto.NewEncryptor(cfg.Crypto.AESKeyBase64) if encErr != nil { logger.Warn("crypto encryptor init failed — prompt encryption disabled", zap.Error(encErr)) } else { encryptor = enc logger.Info("AES-256-GCM encryptor enabled") } } else { logger.Warn("VEYLANT_CRYPTO_AES_KEY_BASE64 not set — prompt encryption disabled") } // ── Load provider configs from DB (supplements static config-based adapters) ─ if db != nil { loadProvidersFromDB(ctx, db, providerRouter, encryptor, logger) } // ── ClickHouse audit logger (optional) ──────────────────────────────────── var auditLogger auditlog.Logger if cfg.ClickHouse.DSN != "" { chLogger, chErr := auditlog.NewClickHouseLogger( cfg.ClickHouse.DSN, cfg.ClickHouse.MaxConns, cfg.ClickHouse.DialTimeoutSec, logger, ) if chErr != nil { if cfg.Server.Env == "development" { logger.Warn("ClickHouse unavailable — audit logging disabled", zap.Error(chErr)) } else { logger.Fatal("ClickHouse init failed", zap.Error(chErr)) } } else { // Apply DDL idempotently. ddlPath := "migrations/clickhouse/000001_audit_logs.sql" if ddlErr := chLogger.ApplyDDL(ddlPath); ddlErr != nil { logger.Warn("ClickHouse DDL apply failed — audit logging disabled", zap.Error(ddlErr)) } else { chLogger.Start() defer chLogger.Stop() auditLogger = chLogger logger.Info("ClickHouse audit logger started", zap.String("dsn", cfg.ClickHouse.DSN)) } } } else { logger.Warn("clickhouse.dsn not set — audit logging disabled") } // ── Feature flag store (E4-12 zero-retention + future flags + E11-07) ────── var flagStore flags.FlagStore if db != nil { flagStore = flags.NewPgFlagStore(db, logger) logger.Info("feature flag store: PostgreSQL") } else { flagStore = flags.NewMemFlagStore() logger.Warn("feature flag store: in-memory (no database)") } // Wire flag store into the provider router so it can check routing_enabled (E11-07). providerRouter.WithFlagStore(flagStore) // ── Proxy handler ───────────────────────────────────────────────────────── proxyHandler := proxy.NewWithAudit(providerRouter, logger, piiClient, auditLogger, encryptor). WithFlagStore(flagStore) // ── Rate limiter (E10-09) ───────────────────────────────────────────────── rateLimiter := ratelimit.New(ratelimit.RateLimitConfig{ RequestsPerMin: cfg.RateLimit.DefaultTenantRPM, BurstSize: cfg.RateLimit.DefaultTenantBurst, UserRPM: cfg.RateLimit.DefaultUserRPM, UserBurst: cfg.RateLimit.DefaultUserBurst, IsEnabled: true, }, logger) // Load per-tenant overrides from DB (best-effort; missing DB is graceful). if db != nil { rlStore := ratelimit.NewStore(db, logger) if overrides, err := rlStore.List(ctx); err == nil { for _, cfg := range overrides { rateLimiter.SetConfig(cfg) } logger.Info("rate limit overrides loaded", zap.Int("count", len(overrides))) } else { logger.Warn("failed to load rate limit overrides", zap.Error(err)) } } logger.Info("rate limiter initialised", zap.Int("default_tenant_rpm", cfg.RateLimit.DefaultTenantRPM), zap.Int("default_user_rpm", cfg.RateLimit.DefaultUserRPM), ) // ── HTTP router ─────────────────────────────────────────────────────────── r := chi.NewRouter() r.Use(middleware.SecurityHeaders(cfg.Server.Env)) r.Use(middleware.RequestID) r.Use(chimiddleware.RealIP) r.Use(chimiddleware.Recoverer) if cfg.Metrics.Enabled { r.Use(metrics.Middleware("openai")) } r.Get("/", health.LandingHandler) r.Get("/healthz", health.Handler) // OpenAPI documentation (E11-02). r.Get("/docs", health.DocsHTMLHandler) r.Get("/docs/openapi.yaml", health.DocsYAMLHandler) // Public PII playground — no JWT required (E8-15). r.Get("/playground", health.PlaygroundHandler) r.Post("/playground/analyze", health.PlaygroundAnalyzeHandler(piiClient, logger)) if cfg.Metrics.Enabled { r.Get(cfg.Metrics.Path, promhttp.Handler().ServeHTTP) } loginHandler := auth.NewLoginHandler(db, cfg.Auth.JWTSecret, cfg.Auth.JWTTTLHours, logger) r.Route("/v1", func(r chi.Router) { r.Use(middleware.CORS(cfg.Server.AllowedOrigins)) // Public — CORS applied, no auth required. r.Post("/auth/login", loginHandler.ServeHTTP) // Protected — JWT auth + tenant rate limit. r.Group(func(r chi.Router) { r.Use(middleware.Auth(jwtVerifier)) r.Use(middleware.RateLimit(rateLimiter)) r.Post("/chat/completions", proxyHandler.ServeHTTP) // PII analyze endpoint for Playground (E8-11, Sprint 8). piiAnalyzeHandler := pii.NewAnalyzeHandler(piiClient, logger) r.Post("/pii/analyze", piiAnalyzeHandler.ServeHTTP) // Admin API — routing policies + audit logs (Sprint 5 + Sprint 6) // + user management + provider status (Sprint 8). if routingEngine != nil { var adminHandler *admin.Handler if auditLogger != nil { adminHandler = admin.NewWithAudit( routing.NewPgStore(db, logger), routingEngine.Cache(), auditLogger, logger, ) } else { adminHandler = admin.New( routing.NewPgStore(db, logger), routingEngine.Cache(), logger, ) } // Wire db, router, rate limiter, feature flags, and encryptor. adminHandler.WithDB(db).WithRouter(providerRouter).WithRateLimiter(rateLimiter).WithFlagStore(flagStore).WithEncryptor(encryptor) r.Route("/admin", adminHandler.Routes) } // Compliance module — GDPR Art. 30 registry + AI Act classification + PDF reports (Sprint 9). if db != nil { compStore := compliance.NewPgStore(db, logger) compHandler := compliance.New(compStore, logger). WithAudit(auditLogger). WithDB(db). WithTenantName(cfg.Server.TenantName) r.Route("/admin/compliance", compHandler.Routes) logger.Info("compliance module started") } }) }) // ── HTTP server ─────────────────────────────────────────────────────────── addr := fmt.Sprintf(":%d", cfg.Server.Port) srv := &http.Server{ Addr: addr, Handler: r, ReadTimeout: 30 * time.Second, WriteTimeout: 30 * time.Second, IdleTimeout: 120 * time.Second, } quit := make(chan os.Signal, 1) signal.Notify(quit, syscall.SIGTERM, syscall.SIGINT) go func() { logger.Info("Veylant IA proxy started", zap.String("addr", addr), zap.String("env", cfg.Server.Env), zap.Bool("metrics", cfg.Metrics.Enabled), zap.String("auth", "local-jwt"), zap.Bool("audit_logging", auditLogger != nil), zap.Bool("encryption", encryptor != nil), ) if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { logger.Fatal("server error", zap.Error(err)) } }() <-quit logger.Info("shutdown signal received, draining connections...") if routingEngine != nil { routingEngine.Stop() } timeout := time.Duration(cfg.Server.ShutdownTimeout) * time.Second shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() if err := srv.Shutdown(shutdownCtx); err != nil { logger.Error("graceful shutdown failed", zap.Error(err)) os.Exit(1) } logger.Info("server stopped cleanly") } // loadProvidersFromDB reads active provider_configs from the database and // registers/updates adapters on the router. Called once at startup so DB-saved // configs take precedence over (or extend) the static config.yaml ones. func loadProvidersFromDB(ctx context.Context, db *sql.DB, providerRouter *router.Router, enc *crypto.Encryptor, logger *zap.Logger) { rows, err := db.QueryContext(ctx, `SELECT provider, api_key_enc, base_url, resource_name, deployment_id, api_version, timeout_sec, max_conns, model_patterns FROM provider_configs WHERE is_active = TRUE`) if err != nil { logger.Warn("failed to load provider configs from DB", zap.Error(err)) return } defer rows.Close() count := 0 for rows.Next() { var ( prov, apiKeyEnc, baseURL, resourceName, deploymentID, apiVersion string timeoutSec, maxConns int patternsRaw string ) if err := rows.Scan(&prov, &apiKeyEnc, &baseURL, &resourceName, &deploymentID, &apiVersion, &timeoutSec, &maxConns, &patternsRaw); err != nil { logger.Warn("failed to scan provider config row", zap.Error(err)) continue } // Decrypt API key apiKey := apiKeyEnc if enc != nil && apiKeyEnc != "" { if dec, decErr := enc.Decrypt(apiKeyEnc); decErr == nil { apiKey = dec } } cfg := &admin.ProviderConfig{ Provider: prov, BaseURL: baseURL, ResourceName: resourceName, DeploymentID: deploymentID, APIVersion: apiVersion, TimeoutSec: timeoutSec, MaxConns: maxConns, } adapter := admin.BuildProviderAdapter(cfg, apiKey) if adapter != nil { providerRouter.UpdateAdapter(prov, adapter) count++ } } if count > 0 { logger.Info("provider configs loaded from DB", zap.Int("count", count)) } } func buildLogger(level, format string) *zap.Logger { lvl := zap.InfoLevel if err := lvl.UnmarshalText([]byte(level)); err != nil { lvl = zap.InfoLevel } encoderCfg := zap.NewProductionEncoderConfig() encoderCfg.TimeKey = "timestamp" encoderCfg.EncodeTime = zapcore.ISO8601TimeEncoder encoding := "json" if format == "console" { encoding = "console" } zapCfg := zap.Config{ Level: zap.NewAtomicLevelAt(lvl), Development: false, Encoding: encoding, EncoderConfig: encoderCfg, OutputPaths: []string{"stdout"}, ErrorOutputPaths: []string{"stderr"}, } logger, err := zapCfg.Build() if err != nil { panic(fmt.Sprintf("failed to build logger: %v", err)) } return logger }