veylant/internal/admin/provider_configs.go
2026-03-06 18:38:04 +01:00

612 lines
19 KiB
Go

package admin
import (
"context"
"database/sql"
"encoding/json"
"net/http"
"strings"
"time"
"github.com/go-chi/chi/v5"
"go.uber.org/zap"
"github.com/veylant/ia-gateway/internal/apierror"
"github.com/veylant/ia-gateway/internal/crypto"
"github.com/veylant/ia-gateway/internal/middleware"
"github.com/veylant/ia-gateway/internal/provider"
"github.com/veylant/ia-gateway/internal/provider/anthropic"
"github.com/veylant/ia-gateway/internal/provider/azure"
"github.com/veylant/ia-gateway/internal/provider/mistral"
"github.com/veylant/ia-gateway/internal/provider/ollama"
"github.com/veylant/ia-gateway/internal/provider/openai"
)
// ProviderAdapterRouter extends ProviderRouter with runtime adapter management.
// The admin handler uses it to apply provider-config changes without a restart.
// It is declared in this package, rather than imported from internal/router, to
// avoid an import cycle.
type ProviderAdapterRouter interface {
	ProviderRouter

	// RemoveAdapter unregisters the adapter registered under name.
	RemoveAdapter(name string)

	// UpdateAdapter installs adapter under name. reloadAdapter calls it with a
	// freshly built adapter after a provider config is saved.
	UpdateAdapter(name string, adapter provider.Adapter)
}
// ProviderConfig is the DB/JSON representation of a configured LLM provider.
// It is the read model returned by the admin API. Incoming writes are decoded
// into upsertProviderRequest instead.
type ProviderConfig struct {
	ID           string    `json:"id"`
	TenantID     string    `json:"tenant_id"`
	Provider     string    `json:"provider"` // one of the identifiers accepted by validProvider
	DisplayName  string    `json:"display_name"`
	APIKeyMasked string    `json:"api_key,omitempty"` // masked form produced by maskAPIKey (set by list/create, left empty by get); never the full key
	BaseURL      string    `json:"base_url,omitempty"` // empty => per-provider default in buildAdapter (ignored for azure)
	ResourceName string    `json:"resource_name,omitempty"` // only read by the azure adapter
	DeploymentID string    `json:"deployment_id,omitempty"` // only read by the azure adapter
	APIVersion   string    `json:"api_version,omitempty"` // only read by the azure adapter
	TimeoutSec   int       `json:"timeout_sec"` // 0 => buildAdapter falls back to a per-provider default
	MaxConns     int       `json:"max_conns"` // 0 => buildAdapter falls back to a per-provider default
	ModelPatterns []string `json:"model_patterns"` // stored in a TEXT[] column
	IsActive     bool      `json:"is_active"`
	CreatedAt    time.Time `json:"created_at"`
	UpdatedAt    time.Time `json:"updated_at"`
}
// upsertProviderRequest is the JSON body accepted by the create and update
// provider-config handlers. Zero values are replaced with defaults by
// providerStore.create before the row is written.
type upsertProviderRequest struct {
	Provider      string   `json:"provider"`
	DisplayName   string   `json:"display_name"` // empty => derived from Provider by displayName
	APIKey        string   `json:"api_key"` // plaintext, encrypted before storage; empty keeps an already-stored key on upsert
	BaseURL       string   `json:"base_url"`
	ResourceName  string   `json:"resource_name"`
	DeploymentID  string   `json:"deployment_id"`
	APIVersion    string   `json:"api_version"` // empty => default applied in providerStore.create
	TimeoutSec    int      `json:"timeout_sec"` // 0 => defaultTimeout(Provider)
	MaxConns      int      `json:"max_conns"` // 0 => default applied in providerStore.create
	ModelPatterns []string `json:"model_patterns"`
	IsActive      *bool    `json:"is_active"` // pointer so "absent" (nil => true) differs from an explicit false
}
// providerStore wraps *sql.DB for provider_configs CRUD.
// A new store is built per request by the handlers via newProviderStore.
type providerStore struct {
	db        *sql.DB
	encryptor *crypto.Encryptor // may be nil: API keys are then stored and read as plaintext (see encryptKey)
	logger    *zap.Logger       // not referenced by the store methods in this file
}
// newProviderStore bundles the database handle, optional key encryptor and
// logger into a providerStore.
func newProviderStore(db *sql.DB, enc *crypto.Encryptor, logger *zap.Logger) *providerStore {
	store := &providerStore{
		logger:    logger,
		encryptor: enc,
		db:        db,
	}
	return store
}
// list returns every provider config belonging to tenantID, oldest first.
// Each entry's APIKeyMasked is derived from the stored key with maskAPIKey.
// The slice is nil when the tenant has no configs. If row iteration fails, the
// rows read so far are returned together with the iteration error.
func (s *providerStore) list(ctx context.Context, tenantID string) ([]ProviderConfig, error) {
	rows, err := s.db.QueryContext(ctx,
		`SELECT id, tenant_id, provider, display_name, api_key_enc, base_url,
resource_name, deployment_id, api_version, timeout_sec, max_conns,
model_patterns, is_active, created_at, updated_at
FROM provider_configs
WHERE tenant_id = $1
ORDER BY created_at ASC`, tenantID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []ProviderConfig
	for rows.Next() {
		var (
			cfg       ProviderConfig
			storedKey string
			patterns  []string
		)
		dest := []interface{}{
			&cfg.ID, &cfg.TenantID, &cfg.Provider, &cfg.DisplayName,
			&storedKey, &cfg.BaseURL, &cfg.ResourceName, &cfg.DeploymentID, &cfg.APIVersion,
			&cfg.TimeoutSec, &cfg.MaxConns, (*pqStringArray)(&patterns),
			&cfg.IsActive, &cfg.CreatedAt, &cfg.UpdatedAt,
		}
		if scanErr := rows.Scan(dest...); scanErr != nil {
			return nil, scanErr
		}
		cfg.ModelPatterns = patterns
		cfg.APIKeyMasked = maskAPIKey(storedKey, s.encryptor)
		out = append(out, cfg)
	}
	return out, rows.Err()
}
// get loads the provider config identified by id within tenantID.
//
// It returns (nil, "", nil) when no such row exists, so callers can tell "not
// found" apart from a query failure. On success the second result is the raw
// api_key_enc column as written by encryptKey. That is normally ciphertext, but
// plaintext when no encryptor is configured or encryption failed.
// APIKeyMasked is left empty. On error both the config and the key are zero.
func (s *providerStore) get(ctx context.Context, id, tenantID string) (*ProviderConfig, string, error) {
	var (
		c         ProviderConfig
		apiKeyEnc string
		patterns  []string
	)
	err := s.db.QueryRowContext(ctx,
		`SELECT id, tenant_id, provider, display_name, api_key_enc, base_url,
resource_name, deployment_id, api_version, timeout_sec, max_conns,
model_patterns, is_active, created_at, updated_at
FROM provider_configs WHERE id = $1 AND tenant_id = $2`, id, tenantID,
	).Scan(&c.ID, &c.TenantID, &c.Provider, &c.DisplayName,
		&apiKeyEnc, &c.BaseURL, &c.ResourceName, &c.DeploymentID, &c.APIVersion,
		&c.TimeoutSec, &c.MaxConns, (*pqStringArray)(&patterns),
		&c.IsActive, &c.CreatedAt, &c.UpdatedAt)
	switch {
	case err == sql.ErrNoRows:
		return nil, "", nil
	case err != nil:
		// Never hand back a partially scanned struct (or key) alongside an error.
		return nil, "", err
	}
	c.ModelPatterns = patterns
	return &c, apiKeyEnc, nil
}
// create inserts a provider config for tenantID, or updates the existing row
// for the same (tenant_id, provider) pair. It is an upsert and is also used by
// the update handler.
//
// These defaults are applied before writing:
//   - is_active is true when req.IsActive is nil.
//   - timeout_sec comes from defaultTimeout.
//   - max_conns is 10 for ollama and 100 otherwise.
//   - api_version is "2024-02-01".
//
// The timeout and connection defaults mirror buildAdapter's fallbacks.
//
// An empty req.APIKey leaves any already-stored key untouched on conflict
// (the CASE on api_key_enc). The returned config carries the masked key.
func (s *providerStore) create(ctx context.Context, tenantID string, req upsertProviderRequest) (*ProviderConfig, error) {
	apiKeyEnc := encryptKey(req.APIKey, s.encryptor)
	isActive := true
	if req.IsActive != nil {
		isActive = *req.IsActive
	}
	timeoutSec := req.TimeoutSec
	if timeoutSec == 0 {
		timeoutSec = defaultTimeout(req.Provider)
	}
	maxConns := req.MaxConns
	if maxConns == 0 {
		// Per-provider fallback, matching buildAdapter (as defaultTimeout does for
		// timeouts). Otherwise ollama's lower connection default would never apply
		// to DB-backed configs.
		maxConns = 100
		if req.Provider == "ollama" {
			maxConns = 10
		}
	}
	apiVersion := req.APIVersion
	if apiVersion == "" {
		apiVersion = "2024-02-01"
	}
	var c ProviderConfig
	var storedPatterns []string
	err := s.db.QueryRowContext(ctx,
		`INSERT INTO provider_configs
(tenant_id, provider, display_name, api_key_enc, base_url, resource_name,
deployment_id, api_version, timeout_sec, max_conns, model_patterns, is_active)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11::text[],$12)
ON CONFLICT (tenant_id, provider) DO UPDATE SET
display_name = EXCLUDED.display_name,
api_key_enc = CASE WHEN EXCLUDED.api_key_enc = '' THEN provider_configs.api_key_enc ELSE EXCLUDED.api_key_enc END,
base_url = EXCLUDED.base_url,
resource_name = EXCLUDED.resource_name,
deployment_id = EXCLUDED.deployment_id,
api_version = EXCLUDED.api_version,
timeout_sec = EXCLUDED.timeout_sec,
max_conns = EXCLUDED.max_conns,
model_patterns= EXCLUDED.model_patterns,
is_active = EXCLUDED.is_active,
updated_at = NOW()
RETURNING id, tenant_id, provider, display_name, api_key_enc, base_url,
resource_name, deployment_id, api_version, timeout_sec, max_conns,
model_patterns, is_active, created_at, updated_at`,
		tenantID, req.Provider, displayName(req), apiKeyEnc, req.BaseURL, req.ResourceName,
		req.DeploymentID, apiVersion, timeoutSec, maxConns,
		modelPatternsLiteral(req.ModelPatterns), isActive,
	).Scan(&c.ID, &c.TenantID, &c.Provider, &c.DisplayName,
		// apiKeyEnc is reused to receive the key actually stored after the upsert.
		&apiKeyEnc, &c.BaseURL, &c.ResourceName, &c.DeploymentID, &c.APIVersion,
		&c.TimeoutSec, &c.MaxConns, (*pqStringArray)(&storedPatterns),
		&c.IsActive, &c.CreatedAt, &c.UpdatedAt)
	if err != nil {
		return nil, err
	}
	c.ModelPatterns = storedPatterns
	c.APIKeyMasked = maskAPIKey(apiKeyEnc, s.encryptor)
	return &c, nil
}
// delete removes the provider config identified by id within tenantID.
// It returns sql.ErrNoRows when no row matched, so callers can map that to
// "not found". Any driver error, including one from RowsAffected, is returned
// as-is rather than being mistaken for a missing row.
func (s *providerStore) delete(ctx context.Context, id, tenantID string) error {
	res, err := s.db.ExecContext(ctx,
		`DELETE FROM provider_configs WHERE id = $1 AND tenant_id = $2`, id, tenantID)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	if err != nil {
		return err
	}
	if n == 0 {
		return sql.ErrNoRows
	}
	return nil
}
// ─── HTTP handlers ────────────────────────────────────────────────────────────
// listProviderConfigs responds with every provider config of the caller's
// tenant, as {"data": [...]}. When a router is wired, each entry is annotated
// with that provider's circuit-breaker state and failure count from
// ProviderStatuses.
func (h *Handler) listProviderConfigs(w http.ResponseWriter, r *http.Request) {
	tenantID, ok := tenantFromCtx(w, r)
	if !ok {
		return
	}
	if h.db == nil {
		apierror.WriteError(w, &apierror.APIError{Type: "not_implemented", Message: "database not configured", HTTPStatus: http.StatusNotImplemented})
		return
	}

	configs, err := newProviderStore(h.db, h.encryptor, h.logger).list(r.Context(), tenantID)
	if err != nil {
		apierror.WriteError(w, apierror.NewUpstreamError("failed to list providers: "+err.Error()))
		return
	}

	// Index live breaker status by provider name. A missing entry reads as the
	// zero value, which the omitempty tags below drop from the JSON.
	type breaker struct {
		state    string
		failures int
	}
	live := make(map[string]breaker)
	if h.router != nil {
		for _, st := range h.router.ProviderStatuses() {
			live[st.Provider] = breaker{state: st.State, failures: st.Failures}
		}
	}

	type enriched struct {
		ProviderConfig
		State    string `json:"state,omitempty"`
		Failures int    `json:"failures,omitempty"`
	}
	// A non-nil slice, so an empty result encodes as [] rather than null.
	result := make([]enriched, 0, len(configs))
	for _, cfg := range configs {
		b := live[cfg.Provider]
		result = append(result, enriched{ProviderConfig: cfg, State: b.state, Failures: b.failures})
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{"data": result})
}
// createProviderConfig saves a provider config for the caller's tenant and
// responds 201 with the stored config (API key masked).
//
// The store performs an upsert keyed on (tenant, provider), so posting an
// already-configured provider updates it in place. Afterwards the provider's
// adapter is rebuilt and hot-swapped on the router.
func (h *Handler) createProviderConfig(w http.ResponseWriter, r *http.Request) {
	tenantID, ok := tenantFromCtx(w, r)
	if !ok {
		return
	}
	if h.db == nil {
		apierror.WriteError(w, &apierror.APIError{Type: "not_implemented", Message: "database not configured", HTTPStatus: http.StatusNotImplemented})
		return
	}
	var req upsertProviderRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		apierror.WriteError(w, apierror.NewBadRequestError("invalid JSON: "+err.Error()))
		return
	}
	if !validProvider(req.Provider) {
		apierror.WriteError(w, apierror.NewBadRequestError("provider must be one of: openai, anthropic, azure, mistral, ollama"))
		return
	}
	ps := newProviderStore(h.db, h.encryptor, h.logger)
	cfg, err := ps.create(r.Context(), tenantID, req)
	if err != nil {
		apierror.WriteError(w, apierror.NewUpstreamError("failed to save provider: "+err.Error()))
		return
	}

	// With no api_key in the body, the upsert keeps whatever key is already
	// stored for this (tenant, provider). The adapter must be rebuilt with that
	// stored key. Reloading it with "" would silently break a working provider.
	apiKey := req.APIKey
	if apiKey == "" {
		_, storedEnc, getErr := ps.get(r.Context(), cfg.ID, tenantID)
		switch {
		case getErr != nil:
			h.logger.Warn("could not load stored API key for adapter reload",
				zap.String("provider", cfg.Provider), zap.String("tenant", tenantID), zap.Error(getErr))
		case storedEnc == "":
			// Nothing stored; the adapter is built without a key.
		case h.encryptor == nil:
			apiKey = storedEnc // dev mode: keys are stored in plaintext
		default:
			if dec, decErr := h.encryptor.Decrypt(storedEnc); decErr == nil {
				apiKey = dec
			} else {
				h.logger.Warn("failed to decrypt stored API key",
					zap.String("provider", cfg.Provider), zap.String("tenant", tenantID), zap.Error(decErr))
			}
		}
	}

	// Rebuild and hot-reload the adapter
	h.reloadAdapter(r.Context(), cfg, apiKey)
	h.logger.Info("provider config saved", zap.String("provider", cfg.Provider), zap.String("tenant", tenantID))
	writeJSON(w, http.StatusCreated, cfg)
}
// updateProviderConfig replaces the settings of the provider config {id} owned
// by the caller's tenant, then rebuilds that provider's adapter.
//
// The provider of an existing config is immutable. The store's upsert is keyed
// on (tenant_id, provider), not on id. A different provider in the body would
// therefore write to, or create, a different row than the one addressed by
// {id}, bypassing validProvider. Such requests are rejected with 400.
//
// An empty api_key keeps the stored key. That key is decrypted so the rebuilt
// adapter keeps working.
func (h *Handler) updateProviderConfig(w http.ResponseWriter, r *http.Request) {
	tenantID, ok := tenantFromCtx(w, r)
	if !ok {
		return
	}
	if h.db == nil {
		apierror.WriteError(w, &apierror.APIError{Type: "not_implemented", Message: "database not configured", HTTPStatus: http.StatusNotImplemented})
		return
	}
	id := chi.URLParam(r, "id")
	var req upsertProviderRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		apierror.WriteError(w, apierror.NewBadRequestError("invalid JSON: "+err.Error()))
		return
	}
	ps := newProviderStore(h.db, h.encryptor, h.logger)
	existing, existingKeyEnc, err := ps.get(r.Context(), id, tenantID)
	if err != nil {
		apierror.WriteError(w, apierror.NewUpstreamError(err.Error()))
		return
	}
	if existing == nil {
		apierror.WriteError(w, &apierror.APIError{Type: "not_found_error", Message: "provider not found", HTTPStatus: http.StatusNotFound})
		return
	}
	switch req.Provider {
	case "":
		req.Provider = existing.Provider
	case existing.Provider:
		// unchanged — fine
	default:
		apierror.WriteError(w, apierror.NewBadRequestError("provider cannot be changed; delete this config and create a new one"))
		return
	}

	// The upsert keeps the stored key when req.APIKey is empty (CASE on
	// api_key_enc), so recover that key in plaintext for the adapter rebuild.
	apiKey := req.APIKey
	if apiKey == "" && existingKeyEnc != "" {
		if h.encryptor == nil {
			apiKey = existingKeyEnc // dev mode: keys are stored in plaintext
		} else if dec, decErr := h.encryptor.Decrypt(existingKeyEnc); decErr == nil {
			apiKey = dec
		} else {
			h.logger.Warn("failed to decrypt stored API key",
				zap.String("provider", existing.Provider), zap.String("id", id), zap.Error(decErr))
		}
	}

	cfg, err := ps.create(r.Context(), tenantID, req) // upsert
	if err != nil {
		apierror.WriteError(w, apierror.NewUpstreamError("failed to update provider: "+err.Error()))
		return
	}
	h.reloadAdapter(r.Context(), cfg, apiKey)
	h.logger.Info("provider config updated", zap.String("provider", cfg.Provider), zap.String("id", id), zap.String("tenant", tenantID))
	writeJSON(w, http.StatusOK, cfg)
}
// deleteProviderConfig removes the provider config {id} owned by the caller's
// tenant, unregisters that provider's adapter from the router, and responds 204.
// It returns 404 when the config does not exist, including when it disappears
// between the lookup and the DELETE.
func (h *Handler) deleteProviderConfig(w http.ResponseWriter, r *http.Request) {
	tenantID, ok := tenantFromCtx(w, r)
	if !ok {
		return
	}
	if h.db == nil {
		apierror.WriteError(w, &apierror.APIError{Type: "not_implemented", Message: "database not configured", HTTPStatus: http.StatusNotImplemented})
		return
	}
	id := chi.URLParam(r, "id")
	ps := newProviderStore(h.db, h.encryptor, h.logger)
	// Look the row up first: its provider name is needed to unregister the adapter.
	existing, _, err := ps.get(r.Context(), id, tenantID)
	if err != nil {
		apierror.WriteError(w, apierror.NewUpstreamError(err.Error()))
		return
	}
	if existing == nil {
		apierror.WriteError(w, &apierror.APIError{Type: "not_found_error", Message: "provider not found", HTTPStatus: http.StatusNotFound})
		return
	}
	if err := ps.delete(r.Context(), id, tenantID); err != nil {
		if err == sql.ErrNoRows {
			// Deleted concurrently after the lookup above: still a 404, not a 502.
			apierror.WriteError(w, &apierror.APIError{Type: "not_found_error", Message: "provider not found", HTTPStatus: http.StatusNotFound})
			return
		}
		apierror.WriteError(w, apierror.NewUpstreamError("failed to delete provider: "+err.Error()))
		return
	}
	// NOTE(review): adapters are registered by provider name only, so this drops
	// the adapter for that provider name regardless of tenant. Whether a provider
	// also configured from the environment survives depends on RemoveAdapter's
	// implementation (not visible here). Confirm.
	if h.adapterRouter != nil {
		h.adapterRouter.RemoveAdapter(existing.Provider)
	}
	h.logger.Info("provider config deleted", zap.String("provider", existing.Provider), zap.String("tenant", tenantID))
	w.WriteHeader(http.StatusNoContent)
}
// testProviderConfig builds a throwaway adapter from the stored config {id}
// and runs its HealthCheck. It always responds 200 with {"healthy": bool}, plus
// "error" when the check fails. Only lookup and adapter-construction problems
// produce error statuses.
func (h *Handler) testProviderConfig(w http.ResponseWriter, r *http.Request) {
	tenantID, ok := tenantFromCtx(w, r)
	if !ok {
		return
	}
	if h.db == nil {
		apierror.WriteError(w, &apierror.APIError{Type: "not_implemented", Message: "database not configured", HTTPStatus: http.StatusNotImplemented})
		return
	}
	id := chi.URLParam(r, "id")
	ps := newProviderStore(h.db, h.encryptor, h.logger)
	existing, apiKeyEnc, err := ps.get(r.Context(), id, tenantID)
	if err != nil {
		apierror.WriteError(w, apierror.NewUpstreamError(err.Error()))
		return
	}
	if existing == nil {
		apierror.WriteError(w, &apierror.APIError{Type: "not_found_error", Message: "provider not found", HTTPStatus: http.StatusNotFound})
		return
	}

	// Recover the plaintext key. A decrypt failure is logged instead of being
	// swallowed. Otherwise the health check fails with an opaque auth error and
	// nothing points at the real cause.
	apiKey := ""
	switch {
	case apiKeyEnc == "":
		// no key stored
	case h.encryptor == nil:
		apiKey = apiKeyEnc // dev mode: keys are stored in plaintext
	default:
		if dec, decErr := h.encryptor.Decrypt(apiKeyEnc); decErr == nil {
			apiKey = dec
		} else {
			h.logger.Warn("failed to decrypt stored API key for health check",
				zap.String("provider", existing.Provider), zap.String("id", id), zap.Error(decErr))
		}
	}

	adapter := buildAdapter(existing, apiKey)
	if adapter == nil {
		apierror.WriteError(w, apierror.NewBadRequestError("cannot build adapter for provider "+existing.Provider))
		return
	}
	if err := adapter.HealthCheck(r.Context()); err != nil {
		writeJSON(w, http.StatusOK, map[string]interface{}{
			"healthy": false,
			"error":   err.Error(),
		})
		return
	}
	writeJSON(w, http.StatusOK, map[string]interface{}{"healthy": true})
}
// ─── Adapter helpers ──────────────────────────────────────────────────────────
// reloadAdapter builds a fresh adapter for cfg using the plaintext apiKey and
// installs it on the router under cfg.Provider. It does nothing when no adapter
// router is wired or when buildAdapter does not recognise cfg.Provider. The
// context parameter is currently unused.
//
// NOTE(review): the router key is the provider name alone, with no tenant. If
// the router is shared across tenants, one tenant's config replaces another's
// adapter. Confirm against the router implementation.
func (h *Handler) reloadAdapter(_ context.Context, cfg *ProviderConfig, apiKey string) {
	router := h.adapterRouter
	if router == nil {
		return
	}
	fresh := buildAdapter(cfg, apiKey)
	if fresh == nil {
		return
	}
	router.UpdateAdapter(cfg.Provider, fresh)
}
// BuildProviderAdapter constructs the provider.Adapter matching cfg.Provider,
// using apiKey as the plaintext credential. It returns nil for unsupported
// providers. It is the exported counterpart of buildAdapter, for callers outside
// this package such as startup code that loads provider configs from the DB.
func BuildProviderAdapter(cfg *ProviderConfig, apiKey string) provider.Adapter {
	adapter := buildAdapter(cfg, apiKey)
	return adapter
}
// buildAdapter constructs the provider.Adapter matching cfg.Provider.
//
// Zero TimeoutSec and MaxConns fall back to 120s and 10 connections for ollama,
// and to 30s and 100 connections otherwise. An empty BaseURL falls back to the
// provider's public endpoint, or to localhost:11434 for ollama. The azure
// adapter ignores BaseURL and uses ResourceName, DeploymentID and APIVersion
// instead. The ollama adapter takes no API key. It returns nil for unsupported
// providers.
func buildAdapter(cfg *ProviderConfig, apiKey string) provider.Adapter {
	timeout, conns := 30, 100
	if cfg.Provider == "ollama" {
		timeout, conns = 120, 10
	}
	timeout = orDefaultInt(cfg.TimeoutSec, timeout)
	conns = orDefaultInt(cfg.MaxConns, conns)

	switch cfg.Provider {
	case "openai":
		return openai.New(openai.Config{
			APIKey:         apiKey,
			BaseURL:        orDefault(cfg.BaseURL, "https://api.openai.com/v1"),
			TimeoutSeconds: timeout,
			MaxConns:       conns,
		})
	case "anthropic":
		return anthropic.New(anthropic.Config{
			APIKey:         apiKey,
			BaseURL:        orDefault(cfg.BaseURL, "https://api.anthropic.com/v1"),
			Version:        "2023-06-01",
			TimeoutSeconds: timeout,
			MaxConns:       conns,
		})
	case "azure":
		return azure.New(azure.Config{
			APIKey:         apiKey,
			ResourceName:   cfg.ResourceName,
			DeploymentID:   cfg.DeploymentID,
			APIVersion:     orDefault(cfg.APIVersion, "2024-02-01"),
			TimeoutSeconds: timeout,
			MaxConns:       conns,
		})
	case "mistral":
		return mistral.New(mistral.Config{
			APIKey:         apiKey,
			BaseURL:        orDefault(cfg.BaseURL, "https://api.mistral.ai/v1"),
			TimeoutSeconds: timeout,
			MaxConns:       conns,
		})
	case "ollama":
		return ollama.New(ollama.Config{
			BaseURL:        orDefault(cfg.BaseURL, "http://localhost:11434/v1"),
			TimeoutSeconds: timeout,
			MaxConns:       conns,
		})
	default:
		return nil
	}
}
// ─── Helpers ──────────────────────────────────────────────────────────────────
// validProvider reports whether p is one of the supported provider identifiers.
// The comparison is exact and case-sensitive.
func validProvider(p string) bool {
	for _, name := range [...]string{"openai", "anthropic", "azure", "mistral", "ollama"} {
		if p == name {
			return true
		}
	}
	return false
}
// defaultTimeout returns the request timeout in seconds used when a config does
// not set one: 120 for ollama, 30 for every other provider.
func defaultTimeout(provider string) int {
	timeout := 30
	if provider == "ollama" {
		timeout = 120
	}
	return timeout
}
// displayName returns req.DisplayName, or, when it is empty, req.Provider with
// its first letter upper-cased (e.g. "openai" -> "Openai").
//
// This replaces the deprecated strings.Title, which needed a //nolint. The
// supported provider identifiers (see validProvider) are single lowercase ASCII
// words, so the result is the same as strings.Title produced for them. Slicing
// by rune keeps a non-ASCII first character intact.
func displayName(req upsertProviderRequest) string {
	if req.DisplayName != "" {
		return req.DisplayName
	}
	runes := []rune(req.Provider)
	if len(runes) == 0 {
		return ""
	}
	return strings.ToUpper(string(runes[:1])) + string(runes[1:])
}
// encryptKey prepares an API key for the api_key_enc column.
//
// An empty key stays empty, which the upsert treats as "keep the stored key".
// Without an encryptor (dev mode) the key is stored as plaintext. Otherwise the
// ciphertext from enc.Encrypt is returned.
//
// NOTE(review): if encryption fails, the key is silently stored in plaintext,
// with no log and no error. Surfacing that failure would need a signature change
// here and in providerStore.create.
func encryptKey(key string, enc *crypto.Encryptor) string {
	switch {
	case key == "":
		return ""
	case enc == nil:
		return key // dev mode: store plaintext
	}
	if ciphertext, err := enc.Encrypt(key); err == nil {
		return ciphertext
	}
	return key // fallback to plaintext on error
}
// maskAPIKey turns a stored api_key_enc value into a display-safe string for
// API responses. The value is decrypted first when an encryptor is available.
// If decryption fails, it is treated as plaintext, since encryptKey stores
// plaintext when there is no encryptor or encryption fails.
//
// Keys of 8 bytes or fewer are fully masked with bullets. Longer keys show a
// prefix followed by eight bullets. The prefix is capped at 8 bytes and at a
// quarter of the key length, so a short key is never mostly revealed. Keys of
// 32 bytes or more show the full 8-byte prefix.
func maskAPIKey(apiKeyEnc string, enc *crypto.Encryptor) string {
	if apiKeyEnc == "" {
		return ""
	}
	plaintext := apiKeyEnc
	if enc != nil {
		if dec, err := enc.Decrypt(apiKeyEnc); err == nil {
			plaintext = dec
		}
	}
	n := len(plaintext)
	if n <= 8 {
		return strings.Repeat("•", n)
	}
	visible := n / 4
	if visible > 8 {
		visible = 8
	}
	return plaintext[:visible] + "••••••••"
}
// orDefault returns s, or def when s is the empty string.
func orDefault(s, def string) string {
	if s != "" {
		return s
	}
	return def
}
// orDefaultInt returns v, or def when v is zero. Negative values are kept.
func orDefaultInt(v, def int) int {
	if v != 0 {
		return v
	}
	return def
}
// pqStringArray is a minimal []string adapter for one-dimensional PostgreSQL
// TEXT[] columns.
type pqStringArray []string

// Scan implements sql.Scanner by parsing PostgreSQL's array text output, e.g.
// {plain,"with,comma","with \"quote\"",NULL}.
//
// Double-quoted elements may contain commas and braces. A backslash escapes the
// next character. An unquoted NULL is treated as a SQL NULL element. Empty
// elements and NULLs are skipped.
//
// A nil source (SQL NULL) or an unsupported source type yields an empty,
// non-nil slice rather than an error.
func (a *pqStringArray) Scan(src interface{}) error {
	var s string
	switch v := src.(type) {
	case string:
		s = v
	case []byte:
		s = string(v)
	default:
		*a = []string{}
		return nil
	}
	s = strings.TrimPrefix(s, "{")
	s = strings.TrimSuffix(s, "}")

	result := []string{}
	var (
		elem     strings.Builder
		quoted   bool // current element contained a double-quoted section
		inQuotes bool // between an opening and a closing double quote
		escaped  bool // previous byte was an unconsumed backslash
	)
	emit := func() {
		v := elem.String()
		if !quoted {
			v = strings.TrimSpace(v)
			// Only the unquoted form denotes NULL; "NULL" in quotes is a literal string.
			if strings.EqualFold(v, "NULL") {
				v = ""
			}
		}
		if v != "" {
			result = append(result, v)
		}
		elem.Reset()
		quoted = false
	}
	// Iterate bytes: every syntax character is ASCII, so multi-byte UTF-8
	// sequences (and any invalid bytes) pass through untouched.
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case escaped:
			elem.WriteByte(c)
			escaped = false
		case c == '\\':
			escaped = true
		case c == '"':
			inQuotes = !inQuotes
			quoted = true
		case c == ',' && !inQuotes:
			emit()
		default:
			elem.WriteByte(c)
		}
	}
	emit()
	*a = result
	return nil
}

// Value renders the array as a PostgreSQL array literal. Every element is
// double-quoted, with backslashes and double quotes backslash-escaped, and an
// empty array renders as "{}".
//
// NOTE: despite its name, this method does NOT satisfy database/sql/driver.Valuer.
// That interface requires the return type driver.Value (a distinct defined
// type), not interface{}. Queries in this file pass modelPatternsLiteral(...)
// with an explicit ::text[] cast instead.
func (a pqStringArray) Value() (interface{}, error) {
	if len(a) == 0 {
		return "{}", nil
	}
	esc := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	parts := make([]string, len(a))
	for i, v := range a {
		parts[i] = `"` + esc.Replace(v) + `"`
	}
	return "{" + strings.Join(parts, ",") + "}", nil
}
// modelPatternsLiteral converts a []string to a PostgreSQL array literal string
// (e.g. {"a","b"}) to be used with $N::text[] in queries. This avoids
// driver.Valuer issues with the pgx stdlib adapter.
//
// Every element is double-quoted. Backslashes and double quotes are
// backslash-escaped, because PostgreSQL treats a backslash inside an array
// literal as an escape. Without that, a pattern such as `gpt-4\.1` would be
// stored as `gpt-4.1`. Nil or empty input yields "{}".
func modelPatternsLiteral(patterns []string) string {
	if len(patterns) == 0 {
		return "{}"
	}
	// Single-pass replacement, so escaped backslashes are not escaped twice.
	esc := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	parts := make([]string, len(patterns))
	for i, v := range patterns {
		parts[i] = `"` + esc.Replace(v) + `"`
	}
	return "{" + strings.Join(parts, ",") + "}"
}
// claimsFromRequest returns the user claims attached to r's context, as reported
// by middleware.ClaimsFromContext. The found/ok flag is deliberately ignored;
// callers receive whatever claims value that lookup yields.
func claimsFromRequest(r *http.Request) *middleware.UserClaims {
	ctx := r.Context()
	claims, _ := middleware.ClaimsFromContext(ctx)
	return claims
}