434 lines
16 KiB
Go
434 lines
16 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"os/signal"
|
|
"syscall"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
chimiddleware "github.com/go-chi/chi/v5/middleware"
|
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
|
"go.uber.org/zap"
|
|
"go.uber.org/zap/zapcore"
|
|
|
|
_ "github.com/jackc/pgx/v5/stdlib" // register pgx driver
|
|
|
|
"github.com/veylant/ia-gateway/internal/admin"
|
|
"github.com/veylant/ia-gateway/internal/auditlog"
|
|
"github.com/veylant/ia-gateway/internal/circuitbreaker"
|
|
"github.com/veylant/ia-gateway/internal/compliance"
|
|
"github.com/veylant/ia-gateway/internal/config"
|
|
"github.com/veylant/ia-gateway/internal/crypto"
|
|
"github.com/veylant/ia-gateway/internal/flags"
|
|
"github.com/veylant/ia-gateway/internal/health"
|
|
"github.com/veylant/ia-gateway/internal/metrics"
|
|
"github.com/veylant/ia-gateway/internal/middleware"
|
|
"github.com/veylant/ia-gateway/internal/pii"
|
|
"github.com/veylant/ia-gateway/internal/provider"
|
|
"github.com/veylant/ia-gateway/internal/provider/anthropic"
|
|
"github.com/veylant/ia-gateway/internal/provider/azure"
|
|
"github.com/veylant/ia-gateway/internal/provider/mistral"
|
|
"github.com/veylant/ia-gateway/internal/provider/ollama"
|
|
"github.com/veylant/ia-gateway/internal/provider/openai"
|
|
"github.com/veylant/ia-gateway/internal/proxy"
|
|
"github.com/veylant/ia-gateway/internal/ratelimit"
|
|
"github.com/veylant/ia-gateway/internal/router"
|
|
"github.com/veylant/ia-gateway/internal/routing"
|
|
)
|
|
|
|
func main() {
|
|
cfg, err := config.Load()
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "failed to load config: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
|
|
logger := buildLogger(cfg.Log.Level, cfg.Log.Format)
|
|
defer logger.Sync() //nolint:errcheck
|
|
|
|
// ── JWT / OIDC verifier ───────────────────────────────────────────────────
|
|
issuerURL := fmt.Sprintf("%s/realms/%s", cfg.Keycloak.BaseURL, cfg.Keycloak.Realm)
|
|
logger.Info("initialising OIDC verifier", zap.String("issuer", issuerURL))
|
|
|
|
ctx := context.Background()
|
|
oidcVerifier, err := middleware.NewOIDCVerifier(ctx, issuerURL, cfg.Keycloak.ClientID)
|
|
if err != nil {
|
|
if cfg.Server.Env == "development" {
|
|
logger.Warn("OIDC verifier unavailable — JWT auth will reject all requests",
|
|
zap.Error(err))
|
|
oidcVerifier = nil
|
|
} else {
|
|
logger.Fatal("failed to initialise OIDC verifier", zap.Error(err))
|
|
}
|
|
}
|
|
|
|
// ── LLM provider adapters ─────────────────────────────────────────────────
|
|
adapters := map[string]provider.Adapter{}
|
|
|
|
adapters["openai"] = openai.New(openai.Config{
|
|
APIKey: cfg.Providers.OpenAI.APIKey,
|
|
BaseURL: cfg.Providers.OpenAI.BaseURL,
|
|
TimeoutSeconds: cfg.Providers.OpenAI.TimeoutSeconds,
|
|
MaxConns: cfg.Providers.OpenAI.MaxConns,
|
|
})
|
|
|
|
if cfg.Providers.Anthropic.APIKey != "" {
|
|
adapters["anthropic"] = anthropic.New(anthropic.Config{
|
|
APIKey: cfg.Providers.Anthropic.APIKey,
|
|
BaseURL: cfg.Providers.Anthropic.BaseURL,
|
|
Version: cfg.Providers.Anthropic.Version,
|
|
TimeoutSeconds: cfg.Providers.Anthropic.TimeoutSeconds,
|
|
MaxConns: cfg.Providers.Anthropic.MaxConns,
|
|
})
|
|
logger.Info("Anthropic adapter enabled")
|
|
}
|
|
|
|
if cfg.Providers.Azure.ResourceName != "" && cfg.Providers.Azure.APIKey != "" {
|
|
adapters["azure"] = azure.New(azure.Config{
|
|
APIKey: cfg.Providers.Azure.APIKey,
|
|
ResourceName: cfg.Providers.Azure.ResourceName,
|
|
DeploymentID: cfg.Providers.Azure.DeploymentID,
|
|
APIVersion: cfg.Providers.Azure.APIVersion,
|
|
TimeoutSeconds: cfg.Providers.Azure.TimeoutSeconds,
|
|
MaxConns: cfg.Providers.Azure.MaxConns,
|
|
})
|
|
logger.Info("Azure OpenAI adapter enabled",
|
|
zap.String("resource", cfg.Providers.Azure.ResourceName),
|
|
zap.String("deployment", cfg.Providers.Azure.DeploymentID),
|
|
)
|
|
}
|
|
|
|
if cfg.Providers.Mistral.APIKey != "" {
|
|
adapters["mistral"] = mistral.New(mistral.Config{
|
|
APIKey: cfg.Providers.Mistral.APIKey,
|
|
BaseURL: cfg.Providers.Mistral.BaseURL,
|
|
TimeoutSeconds: cfg.Providers.Mistral.TimeoutSeconds,
|
|
MaxConns: cfg.Providers.Mistral.MaxConns,
|
|
})
|
|
logger.Info("Mistral adapter enabled")
|
|
}
|
|
|
|
adapters["ollama"] = ollama.New(ollama.Config{
|
|
BaseURL: cfg.Providers.Ollama.BaseURL,
|
|
TimeoutSeconds: cfg.Providers.Ollama.TimeoutSeconds,
|
|
MaxConns: cfg.Providers.Ollama.MaxConns,
|
|
})
|
|
logger.Info("Ollama adapter enabled", zap.String("base_url", cfg.Providers.Ollama.BaseURL))
|
|
|
|
// ── Database (PostgreSQL via pgx) ─────────────────────────────────────────
|
|
var db *sql.DB
|
|
if cfg.Database.URL != "" {
|
|
var dbErr error
|
|
db, dbErr = sql.Open("pgx", cfg.Database.URL)
|
|
if dbErr != nil {
|
|
logger.Fatal("failed to open database", zap.Error(dbErr))
|
|
}
|
|
db.SetMaxOpenConns(cfg.Database.MaxOpenConns)
|
|
db.SetMaxIdleConns(cfg.Database.MaxIdleConns)
|
|
if pingErr := db.PingContext(ctx); pingErr != nil {
|
|
if cfg.Server.Env == "development" {
|
|
logger.Warn("database unavailable — routing engine disabled", zap.Error(pingErr))
|
|
db = nil
|
|
} else {
|
|
logger.Fatal("database ping failed", zap.Error(pingErr))
|
|
}
|
|
} else {
|
|
logger.Info("database connected", zap.String("url", cfg.Database.URL))
|
|
}
|
|
}
|
|
|
|
// ── Routing engine ────────────────────────────────────────────────────────
|
|
var routingEngine *routing.Engine
|
|
if db != nil {
|
|
ttl := time.Duration(cfg.Routing.CacheTTLSeconds) * time.Second
|
|
if ttl <= 0 {
|
|
ttl = 30 * time.Second
|
|
}
|
|
pgStore := routing.NewPgStore(db, logger)
|
|
routingEngine = routing.New(pgStore, ttl, logger)
|
|
routingEngine.Start()
|
|
logger.Info("routing engine started", zap.Duration("cache_ttl", ttl))
|
|
}
|
|
|
|
// ── Circuit breaker (E2-09) ───────────────────────────────────────────────
|
|
cb := circuitbreaker.New(5, 60*time.Second)
|
|
logger.Info("circuit breaker initialised", zap.Int("threshold", 5), zap.Duration("open_ttl", 60*time.Second))
|
|
|
|
// ── Provider router (RBAC + model dispatch + optional engine) ─────────────
|
|
providerRouter := router.NewWithEngineAndBreaker(adapters, &cfg.RBAC, routingEngine, cb, logger)
|
|
logger.Info("provider router initialised",
|
|
zap.Int("adapter_count", len(adapters)),
|
|
zap.Strings("user_allowed_models", cfg.RBAC.UserAllowedModels),
|
|
zap.Bool("routing_engine", routingEngine != nil),
|
|
)
|
|
|
|
// ── PII client (optional) ─────────────────────────────────────────────────
|
|
var piiClient *pii.Client
|
|
if cfg.PII.Enabled {
|
|
pc, piiErr := pii.New(pii.Config{
|
|
Address: cfg.PII.ServiceAddr,
|
|
Timeout: time.Duration(cfg.PII.TimeoutMs) * time.Millisecond,
|
|
FailOpen: cfg.PII.FailOpen,
|
|
}, logger)
|
|
if piiErr != nil {
|
|
logger.Warn("PII client init failed — PII disabled", zap.Error(piiErr))
|
|
} else {
|
|
piiClient = pc
|
|
defer pc.Close() //nolint:errcheck
|
|
logger.Info("PII client connected", zap.String("addr", cfg.PII.ServiceAddr))
|
|
}
|
|
}
|
|
|
|
// ── AES-256-GCM encryptor (optional) ─────────────────────────────────────
|
|
var encryptor *crypto.Encryptor
|
|
if cfg.Crypto.AESKeyBase64 != "" {
|
|
enc, encErr := crypto.NewEncryptor(cfg.Crypto.AESKeyBase64)
|
|
if encErr != nil {
|
|
logger.Warn("crypto encryptor init failed — prompt encryption disabled", zap.Error(encErr))
|
|
} else {
|
|
encryptor = enc
|
|
logger.Info("AES-256-GCM encryptor enabled")
|
|
}
|
|
} else {
|
|
logger.Warn("VEYLANT_CRYPTO_AES_KEY_BASE64 not set — prompt encryption disabled")
|
|
}
|
|
|
|
// ── ClickHouse audit logger (optional) ────────────────────────────────────
|
|
var auditLogger auditlog.Logger
|
|
if cfg.ClickHouse.DSN != "" {
|
|
chLogger, chErr := auditlog.NewClickHouseLogger(
|
|
cfg.ClickHouse.DSN,
|
|
cfg.ClickHouse.MaxConns,
|
|
cfg.ClickHouse.DialTimeoutSec,
|
|
logger,
|
|
)
|
|
if chErr != nil {
|
|
if cfg.Server.Env == "development" {
|
|
logger.Warn("ClickHouse unavailable — audit logging disabled", zap.Error(chErr))
|
|
} else {
|
|
logger.Fatal("ClickHouse init failed", zap.Error(chErr))
|
|
}
|
|
} else {
|
|
// Apply DDL idempotently.
|
|
ddlPath := "migrations/clickhouse/000001_audit_logs.sql"
|
|
if ddlErr := chLogger.ApplyDDL(ddlPath); ddlErr != nil {
|
|
logger.Warn("ClickHouse DDL apply failed — audit logging disabled", zap.Error(ddlErr))
|
|
} else {
|
|
chLogger.Start()
|
|
defer chLogger.Stop()
|
|
auditLogger = chLogger
|
|
logger.Info("ClickHouse audit logger started", zap.String("dsn", cfg.ClickHouse.DSN))
|
|
}
|
|
}
|
|
} else {
|
|
logger.Warn("clickhouse.dsn not set — audit logging disabled")
|
|
}
|
|
|
|
// ── Feature flag store (E4-12 zero-retention + future flags + E11-07) ──────
|
|
var flagStore flags.FlagStore
|
|
if db != nil {
|
|
flagStore = flags.NewPgFlagStore(db, logger)
|
|
logger.Info("feature flag store: PostgreSQL")
|
|
} else {
|
|
flagStore = flags.NewMemFlagStore()
|
|
logger.Warn("feature flag store: in-memory (no database)")
|
|
}
|
|
// Wire flag store into the provider router so it can check routing_enabled (E11-07).
|
|
providerRouter.WithFlagStore(flagStore)
|
|
|
|
// ── Proxy handler ─────────────────────────────────────────────────────────
|
|
proxyHandler := proxy.NewWithAudit(providerRouter, logger, piiClient, auditLogger, encryptor).
|
|
WithFlagStore(flagStore)
|
|
|
|
// ── Rate limiter (E10-09) ─────────────────────────────────────────────────
|
|
rateLimiter := ratelimit.New(ratelimit.RateLimitConfig{
|
|
RequestsPerMin: cfg.RateLimit.DefaultTenantRPM,
|
|
BurstSize: cfg.RateLimit.DefaultTenantBurst,
|
|
UserRPM: cfg.RateLimit.DefaultUserRPM,
|
|
UserBurst: cfg.RateLimit.DefaultUserBurst,
|
|
IsEnabled: true,
|
|
}, logger)
|
|
// Load per-tenant overrides from DB (best-effort; missing DB is graceful).
|
|
if db != nil {
|
|
rlStore := ratelimit.NewStore(db, logger)
|
|
if overrides, err := rlStore.List(ctx); err == nil {
|
|
for _, cfg := range overrides {
|
|
rateLimiter.SetConfig(cfg)
|
|
}
|
|
logger.Info("rate limit overrides loaded", zap.Int("count", len(overrides)))
|
|
} else {
|
|
logger.Warn("failed to load rate limit overrides", zap.Error(err))
|
|
}
|
|
}
|
|
logger.Info("rate limiter initialised",
|
|
zap.Int("default_tenant_rpm", cfg.RateLimit.DefaultTenantRPM),
|
|
zap.Int("default_user_rpm", cfg.RateLimit.DefaultUserRPM),
|
|
)
|
|
|
|
// ── HTTP router ───────────────────────────────────────────────────────────
|
|
r := chi.NewRouter()
|
|
|
|
r.Use(middleware.SecurityHeaders(cfg.Server.Env))
|
|
r.Use(middleware.RequestID)
|
|
r.Use(chimiddleware.RealIP)
|
|
r.Use(chimiddleware.Recoverer)
|
|
|
|
if cfg.Metrics.Enabled {
|
|
r.Use(metrics.Middleware("openai"))
|
|
}
|
|
|
|
r.Get("/healthz", health.Handler)
|
|
|
|
// OpenAPI documentation (E11-02).
|
|
r.Get("/docs", health.DocsHTMLHandler)
|
|
r.Get("/docs/openapi.yaml", health.DocsYAMLHandler)
|
|
|
|
// Public PII playground — no JWT required (E8-15).
|
|
r.Get("/playground", health.PlaygroundHandler)
|
|
r.Post("/playground/analyze", health.PlaygroundAnalyzeHandler(piiClient, logger))
|
|
|
|
if cfg.Metrics.Enabled {
|
|
r.Get(cfg.Metrics.Path, promhttp.Handler().ServeHTTP)
|
|
}
|
|
|
|
r.Route("/v1", func(r chi.Router) {
|
|
r.Use(middleware.CORS(cfg.Server.AllowedOrigins))
|
|
var authMW func(http.Handler) http.Handler
|
|
if oidcVerifier != nil {
|
|
authMW = middleware.Auth(oidcVerifier)
|
|
} else {
|
|
authMW = middleware.Auth(&middleware.MockVerifier{
|
|
Claims: &middleware.UserClaims{
|
|
UserID: "dev-user",
|
|
TenantID: "00000000-0000-0000-0000-000000000001",
|
|
Email: "dev@veylant.local",
|
|
Roles: []string{"admin"},
|
|
},
|
|
})
|
|
logger.Warn("running in DEV mode — JWT validation is DISABLED")
|
|
}
|
|
r.Use(authMW)
|
|
r.Use(middleware.RateLimit(rateLimiter))
|
|
r.Post("/chat/completions", proxyHandler.ServeHTTP)
|
|
|
|
// PII analyze endpoint for Playground (E8-11, Sprint 8).
|
|
piiAnalyzeHandler := pii.NewAnalyzeHandler(piiClient, logger)
|
|
r.Post("/pii/analyze", piiAnalyzeHandler.ServeHTTP)
|
|
|
|
// Admin API — routing policies + audit logs (Sprint 5 + Sprint 6)
|
|
// + user management + provider status (Sprint 8).
|
|
if routingEngine != nil {
|
|
var adminHandler *admin.Handler
|
|
if auditLogger != nil {
|
|
adminHandler = admin.NewWithAudit(
|
|
routing.NewPgStore(db, logger),
|
|
routingEngine.Cache(),
|
|
auditLogger,
|
|
logger,
|
|
)
|
|
} else {
|
|
adminHandler = admin.New(
|
|
routing.NewPgStore(db, logger),
|
|
routingEngine.Cache(),
|
|
logger,
|
|
)
|
|
}
|
|
// Wire db, router, rate limiter, and feature flags (Sprint 8 + Sprint 10 + Sprint 11).
|
|
adminHandler.WithDB(db).WithRouter(providerRouter).WithRateLimiter(rateLimiter).WithFlagStore(flagStore)
|
|
r.Route("/admin", adminHandler.Routes)
|
|
}
|
|
|
|
// Compliance module — GDPR Art. 30 registry + AI Act classification + PDF reports (Sprint 9).
|
|
if db != nil {
|
|
compStore := compliance.NewPgStore(db, logger)
|
|
compHandler := compliance.New(compStore, logger).
|
|
WithAudit(auditLogger).
|
|
WithDB(db).
|
|
WithTenantName(cfg.Server.TenantName)
|
|
r.Route("/admin/compliance", compHandler.Routes)
|
|
logger.Info("compliance module started")
|
|
}
|
|
})
|
|
|
|
// ── HTTP server ───────────────────────────────────────────────────────────
|
|
addr := fmt.Sprintf(":%d", cfg.Server.Port)
|
|
srv := &http.Server{
|
|
Addr: addr,
|
|
Handler: r,
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 30 * time.Second,
|
|
IdleTimeout: 120 * time.Second,
|
|
}
|
|
|
|
quit := make(chan os.Signal, 1)
|
|
signal.Notify(quit, syscall.SIGTERM, syscall.SIGINT)
|
|
|
|
go func() {
|
|
logger.Info("Veylant IA proxy started",
|
|
zap.String("addr", addr),
|
|
zap.String("env", cfg.Server.Env),
|
|
zap.Bool("metrics", cfg.Metrics.Enabled),
|
|
zap.String("oidc_issuer", issuerURL),
|
|
zap.Bool("audit_logging", auditLogger != nil),
|
|
zap.Bool("encryption", encryptor != nil),
|
|
)
|
|
if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
logger.Fatal("server error", zap.Error(err))
|
|
}
|
|
}()
|
|
|
|
<-quit
|
|
logger.Info("shutdown signal received, draining connections...")
|
|
|
|
if routingEngine != nil {
|
|
routingEngine.Stop()
|
|
}
|
|
|
|
timeout := time.Duration(cfg.Server.ShutdownTimeout) * time.Second
|
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
defer cancel()
|
|
|
|
if err := srv.Shutdown(shutdownCtx); err != nil {
|
|
logger.Error("graceful shutdown failed", zap.Error(err))
|
|
os.Exit(1)
|
|
}
|
|
logger.Info("server stopped cleanly")
|
|
}
|
|
|
|
func buildLogger(level, format string) *zap.Logger {
|
|
lvl := zap.InfoLevel
|
|
if err := lvl.UnmarshalText([]byte(level)); err != nil {
|
|
lvl = zap.InfoLevel
|
|
}
|
|
|
|
encoderCfg := zap.NewProductionEncoderConfig()
|
|
encoderCfg.TimeKey = "timestamp"
|
|
encoderCfg.EncodeTime = zapcore.ISO8601TimeEncoder
|
|
|
|
encoding := "json"
|
|
if format == "console" {
|
|
encoding = "console"
|
|
}
|
|
|
|
zapCfg := zap.Config{
|
|
Level: zap.NewAtomicLevelAt(lvl),
|
|
Development: false,
|
|
Encoding: encoding,
|
|
EncoderConfig: encoderCfg,
|
|
OutputPaths: []string{"stdout"},
|
|
ErrorOutputPaths: []string{"stderr"},
|
|
}
|
|
|
|
logger, err := zapCfg.Build()
|
|
if err != nil {
|
|
panic(fmt.Sprintf("failed to build logger: %v", err))
|
|
}
|
|
return logger
|
|
}
|