612 lines
19 KiB
Go
612 lines
19 KiB
Go
package admin
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/veylant/ia-gateway/internal/apierror"
|
|
"github.com/veylant/ia-gateway/internal/crypto"
|
|
"github.com/veylant/ia-gateway/internal/middleware"
|
|
"github.com/veylant/ia-gateway/internal/provider"
|
|
"github.com/veylant/ia-gateway/internal/provider/anthropic"
|
|
"github.com/veylant/ia-gateway/internal/provider/azure"
|
|
"github.com/veylant/ia-gateway/internal/provider/mistral"
|
|
"github.com/veylant/ia-gateway/internal/provider/ollama"
|
|
"github.com/veylant/ia-gateway/internal/provider/openai"
|
|
)
|
|
|
|
// ProviderAdapterRouter is the interface used by the admin handler to update
|
|
// adapters at runtime. Defined here to avoid import cycles with internal/router.
|
|
type ProviderAdapterRouter interface {
|
|
ProviderRouter
|
|
UpdateAdapter(name string, adapter provider.Adapter)
|
|
RemoveAdapter(name string)
|
|
}
|
|
|
|
// ProviderConfig is one tenant-scoped LLM provider configuration: a row of the
// provider_configs table and the JSON object returned by the admin API.
//
// Field order is significant: encoding/json emits fields in declaration order.
type ProviderConfig struct {
	ID          string `json:"id"`
	TenantID    string `json:"tenant_id"`
	Provider    string `json:"provider"` // openai, anthropic, azure, mistral or ollama
	DisplayName string `json:"display_name"`

	// APIKeyMasked never carries a usable secret: list and create fill it with
	// maskAPIKey's output, and get leaves it empty. Plaintext keys travel only
	// in upsertProviderRequest.APIKey.
	APIKeyMasked string `json:"api_key,omitempty"`

	// Connection settings. Which ones matter depends on Provider: buildAdapter
	// reads ResourceName, DeploymentID and APIVersion for azure only.
	BaseURL      string `json:"base_url,omitempty"`
	ResourceName string `json:"resource_name,omitempty"`
	DeploymentID string `json:"deployment_id,omitempty"`
	APIVersion   string `json:"api_version,omitempty"`

	// Transport tuning; zero means "use buildAdapter's default".
	TimeoutSec int `json:"timeout_sec"`
	MaxConns   int `json:"max_conns"`

	ModelPatterns []string  `json:"model_patterns"`
	IsActive      bool      `json:"is_active"`
	CreatedAt     time.Time `json:"created_at"`
	UpdatedAt     time.Time `json:"updated_at"`
}
|
|
|
|
// upsertProviderRequest is the JSON body accepted when creating or updating a
// provider configuration. Empty/zero TimeoutSec, MaxConns and APIVersion select
// defaults, a nil IsActive means "active", and an empty APIKey keeps the key
// already stored for the same tenant and provider.
type upsertProviderRequest struct {
	Provider    string `json:"provider"`
	DisplayName string `json:"display_name"`
	// APIKey arrives in plaintext and is encrypted before it reaches the table.
	APIKey string `json:"api_key"`

	BaseURL      string `json:"base_url"`
	ResourceName string `json:"resource_name"`
	DeploymentID string `json:"deployment_id"`
	APIVersion   string `json:"api_version"`

	TimeoutSec    int      `json:"timeout_sec"`
	MaxConns      int      `json:"max_conns"`
	ModelPatterns []string `json:"model_patterns"`
	// IsActive is a pointer so "absent" can be told apart from an explicit false.
	IsActive *bool `json:"is_active"`
}
|
|
|
|
// providerStore wraps *sql.DB for provider_configs CRUD.
|
|
type providerStore struct {
|
|
db *sql.DB
|
|
encryptor *crypto.Encryptor
|
|
logger *zap.Logger
|
|
}
|
|
|
|
func newProviderStore(db *sql.DB, enc *crypto.Encryptor, logger *zap.Logger) *providerStore {
|
|
return &providerStore{db: db, encryptor: enc, logger: logger}
|
|
}
|
|
|
|
func (s *providerStore) list(ctx context.Context, tenantID string) ([]ProviderConfig, error) {
|
|
rows, err := s.db.QueryContext(ctx,
|
|
`SELECT id, tenant_id, provider, display_name, api_key_enc, base_url,
|
|
resource_name, deployment_id, api_version, timeout_sec, max_conns,
|
|
model_patterns, is_active, created_at, updated_at
|
|
FROM provider_configs
|
|
WHERE tenant_id = $1
|
|
ORDER BY created_at ASC`, tenantID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var configs []ProviderConfig
|
|
for rows.Next() {
|
|
var c ProviderConfig
|
|
var apiKeyEnc string
|
|
var patterns []string
|
|
if err := rows.Scan(&c.ID, &c.TenantID, &c.Provider, &c.DisplayName,
|
|
&apiKeyEnc, &c.BaseURL, &c.ResourceName, &c.DeploymentID, &c.APIVersion,
|
|
&c.TimeoutSec, &c.MaxConns, (*pqStringArray)(&patterns),
|
|
&c.IsActive, &c.CreatedAt, &c.UpdatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
c.ModelPatterns = patterns
|
|
c.APIKeyMasked = maskAPIKey(apiKeyEnc, s.encryptor)
|
|
configs = append(configs, c)
|
|
}
|
|
return configs, rows.Err()
|
|
}
|
|
|
|
func (s *providerStore) get(ctx context.Context, id, tenantID string) (*ProviderConfig, string, error) {
|
|
var c ProviderConfig
|
|
var apiKeyEnc string
|
|
var patterns []string
|
|
err := s.db.QueryRowContext(ctx,
|
|
`SELECT id, tenant_id, provider, display_name, api_key_enc, base_url,
|
|
resource_name, deployment_id, api_version, timeout_sec, max_conns,
|
|
model_patterns, is_active, created_at, updated_at
|
|
FROM provider_configs WHERE id = $1 AND tenant_id = $2`, id, tenantID,
|
|
).Scan(&c.ID, &c.TenantID, &c.Provider, &c.DisplayName,
|
|
&apiKeyEnc, &c.BaseURL, &c.ResourceName, &c.DeploymentID, &c.APIVersion,
|
|
&c.TimeoutSec, &c.MaxConns, (*pqStringArray)(&patterns),
|
|
&c.IsActive, &c.CreatedAt, &c.UpdatedAt)
|
|
if err == sql.ErrNoRows {
|
|
return nil, "", nil
|
|
}
|
|
c.ModelPatterns = patterns
|
|
return &c, apiKeyEnc, err
|
|
}
|
|
|
|
func (s *providerStore) create(ctx context.Context, tenantID string, req upsertProviderRequest) (*ProviderConfig, error) {
|
|
apiKeyEnc := encryptKey(req.APIKey, s.encryptor)
|
|
isActive := true
|
|
if req.IsActive != nil {
|
|
isActive = *req.IsActive
|
|
}
|
|
timeoutSec := req.TimeoutSec
|
|
if timeoutSec == 0 {
|
|
timeoutSec = defaultTimeout(req.Provider)
|
|
}
|
|
maxConns := req.MaxConns
|
|
if maxConns == 0 {
|
|
maxConns = 100
|
|
}
|
|
apiVersion := req.APIVersion
|
|
if apiVersion == "" {
|
|
apiVersion = "2024-02-01"
|
|
}
|
|
|
|
var c ProviderConfig
|
|
var storedPatterns []string
|
|
err := s.db.QueryRowContext(ctx,
|
|
`INSERT INTO provider_configs
|
|
(tenant_id, provider, display_name, api_key_enc, base_url, resource_name,
|
|
deployment_id, api_version, timeout_sec, max_conns, model_patterns, is_active)
|
|
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11::text[],$12)
|
|
ON CONFLICT (tenant_id, provider) DO UPDATE SET
|
|
display_name = EXCLUDED.display_name,
|
|
api_key_enc = CASE WHEN EXCLUDED.api_key_enc = '' THEN provider_configs.api_key_enc ELSE EXCLUDED.api_key_enc END,
|
|
base_url = EXCLUDED.base_url,
|
|
resource_name = EXCLUDED.resource_name,
|
|
deployment_id = EXCLUDED.deployment_id,
|
|
api_version = EXCLUDED.api_version,
|
|
timeout_sec = EXCLUDED.timeout_sec,
|
|
max_conns = EXCLUDED.max_conns,
|
|
model_patterns= EXCLUDED.model_patterns,
|
|
is_active = EXCLUDED.is_active,
|
|
updated_at = NOW()
|
|
RETURNING id, tenant_id, provider, display_name, api_key_enc, base_url,
|
|
resource_name, deployment_id, api_version, timeout_sec, max_conns,
|
|
model_patterns, is_active, created_at, updated_at`,
|
|
tenantID, req.Provider, displayName(req), apiKeyEnc, req.BaseURL, req.ResourceName,
|
|
req.DeploymentID, apiVersion, timeoutSec, maxConns,
|
|
modelPatternsLiteral(req.ModelPatterns), isActive,
|
|
).Scan(&c.ID, &c.TenantID, &c.Provider, &c.DisplayName,
|
|
&apiKeyEnc, &c.BaseURL, &c.ResourceName, &c.DeploymentID, &c.APIVersion,
|
|
&c.TimeoutSec, &c.MaxConns, (*pqStringArray)(&storedPatterns),
|
|
&c.IsActive, &c.CreatedAt, &c.UpdatedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
c.ModelPatterns = storedPatterns
|
|
c.APIKeyMasked = maskAPIKey(apiKeyEnc, s.encryptor)
|
|
return &c, nil
|
|
}
|
|
|
|
func (s *providerStore) delete(ctx context.Context, id, tenantID string) error {
|
|
res, err := s.db.ExecContext(ctx,
|
|
`DELETE FROM provider_configs WHERE id = $1 AND tenant_id = $2`, id, tenantID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
if n == 0 {
|
|
return sql.ErrNoRows
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ─── HTTP handlers ────────────────────────────────────────────────────────────
|
|
|
|
func (h *Handler) listProviderConfigs(w http.ResponseWriter, r *http.Request) {
|
|
tenantID, ok := tenantFromCtx(w, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
if h.db == nil {
|
|
apierror.WriteError(w, &apierror.APIError{Type: "not_implemented", Message: "database not configured", HTTPStatus: http.StatusNotImplemented})
|
|
return
|
|
}
|
|
ps := newProviderStore(h.db, h.encryptor, h.logger)
|
|
configs, err := ps.list(r.Context(), tenantID)
|
|
if err != nil {
|
|
apierror.WriteError(w, apierror.NewUpstreamError("failed to list providers: "+err.Error()))
|
|
return
|
|
}
|
|
if configs == nil {
|
|
configs = []ProviderConfig{}
|
|
}
|
|
|
|
// Enrich with circuit breaker status if available
|
|
type enriched struct {
|
|
ProviderConfig
|
|
State string `json:"state,omitempty"`
|
|
Failures int `json:"failures,omitempty"`
|
|
}
|
|
statuses := map[string]string{}
|
|
failures := map[string]int{}
|
|
if h.router != nil {
|
|
for _, s := range h.router.ProviderStatuses() {
|
|
statuses[s.Provider] = s.State
|
|
failures[s.Provider] = s.Failures
|
|
}
|
|
}
|
|
result := make([]enriched, len(configs))
|
|
for i, c := range configs {
|
|
result[i] = enriched{
|
|
ProviderConfig: c,
|
|
State: statuses[c.Provider],
|
|
Failures: failures[c.Provider],
|
|
}
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]interface{}{"data": result})
|
|
}
|
|
|
|
func (h *Handler) createProviderConfig(w http.ResponseWriter, r *http.Request) {
|
|
tenantID, ok := tenantFromCtx(w, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
if h.db == nil {
|
|
apierror.WriteError(w, &apierror.APIError{Type: "not_implemented", Message: "database not configured", HTTPStatus: http.StatusNotImplemented})
|
|
return
|
|
}
|
|
var req upsertProviderRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
apierror.WriteError(w, apierror.NewBadRequestError("invalid JSON: "+err.Error()))
|
|
return
|
|
}
|
|
if !validProvider(req.Provider) {
|
|
apierror.WriteError(w, apierror.NewBadRequestError("provider must be one of: openai, anthropic, azure, mistral, ollama"))
|
|
return
|
|
}
|
|
|
|
ps := newProviderStore(h.db, h.encryptor, h.logger)
|
|
cfg, err := ps.create(r.Context(), tenantID, req)
|
|
if err != nil {
|
|
apierror.WriteError(w, apierror.NewUpstreamError("failed to save provider: "+err.Error()))
|
|
return
|
|
}
|
|
|
|
// Rebuild and hot-reload the adapter
|
|
h.reloadAdapter(r.Context(), cfg, req.APIKey)
|
|
h.logger.Info("provider config saved", zap.String("provider", cfg.Provider), zap.String("tenant", tenantID))
|
|
writeJSON(w, http.StatusCreated, cfg)
|
|
}
|
|
|
|
func (h *Handler) updateProviderConfig(w http.ResponseWriter, r *http.Request) {
|
|
tenantID, ok := tenantFromCtx(w, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
if h.db == nil {
|
|
apierror.WriteError(w, &apierror.APIError{Type: "not_implemented", Message: "database not configured", HTTPStatus: http.StatusNotImplemented})
|
|
return
|
|
}
|
|
id := chi.URLParam(r, "id")
|
|
var req upsertProviderRequest
|
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
|
apierror.WriteError(w, apierror.NewBadRequestError("invalid JSON: "+err.Error()))
|
|
return
|
|
}
|
|
|
|
ps := newProviderStore(h.db, h.encryptor, h.logger)
|
|
// Fetch existing to get provider name if not supplied
|
|
existing, existingKeyEnc, err := ps.get(r.Context(), id, tenantID)
|
|
if err != nil {
|
|
apierror.WriteError(w, apierror.NewUpstreamError(err.Error()))
|
|
return
|
|
}
|
|
if existing == nil {
|
|
apierror.WriteError(w, &apierror.APIError{Type: "not_found_error", Message: "provider not found", HTTPStatus: http.StatusNotFound})
|
|
return
|
|
}
|
|
if req.Provider == "" {
|
|
req.Provider = existing.Provider
|
|
}
|
|
|
|
// If no new API key provided, keep the existing encrypted one by passing empty string
|
|
// (the upsert handles this via the CASE statement)
|
|
decryptedKey := req.APIKey
|
|
if decryptedKey == "" && existingKeyEnc != "" {
|
|
// Decrypt existing key to pass to adapter rebuild
|
|
if h.encryptor != nil {
|
|
if dec, decErr := h.encryptor.Decrypt(existingKeyEnc); decErr == nil {
|
|
decryptedKey = dec
|
|
}
|
|
} else {
|
|
decryptedKey = existingKeyEnc
|
|
}
|
|
}
|
|
|
|
cfg, err := ps.create(r.Context(), tenantID, req) // upsert
|
|
if err != nil {
|
|
apierror.WriteError(w, apierror.NewUpstreamError("failed to update provider: "+err.Error()))
|
|
return
|
|
}
|
|
|
|
h.reloadAdapter(r.Context(), cfg, decryptedKey)
|
|
h.logger.Info("provider config updated", zap.String("provider", cfg.Provider), zap.String("id", id))
|
|
writeJSON(w, http.StatusOK, cfg)
|
|
}
|
|
|
|
func (h *Handler) deleteProviderConfig(w http.ResponseWriter, r *http.Request) {
|
|
tenantID, ok := tenantFromCtx(w, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
if h.db == nil {
|
|
apierror.WriteError(w, &apierror.APIError{Type: "not_implemented", Message: "database not configured", HTTPStatus: http.StatusNotImplemented})
|
|
return
|
|
}
|
|
id := chi.URLParam(r, "id")
|
|
|
|
ps := newProviderStore(h.db, h.encryptor, h.logger)
|
|
existing, _, err := ps.get(r.Context(), id, tenantID)
|
|
if err != nil {
|
|
apierror.WriteError(w, apierror.NewUpstreamError(err.Error()))
|
|
return
|
|
}
|
|
if existing == nil {
|
|
apierror.WriteError(w, &apierror.APIError{Type: "not_found_error", Message: "provider not found", HTTPStatus: http.StatusNotFound})
|
|
return
|
|
}
|
|
|
|
if err := ps.delete(r.Context(), id, tenantID); err != nil {
|
|
apierror.WriteError(w, apierror.NewUpstreamError("failed to delete provider: "+err.Error()))
|
|
return
|
|
}
|
|
|
|
// Remove adapter from router (config-based providers survive if they have an API key in env)
|
|
if h.adapterRouter != nil {
|
|
h.adapterRouter.RemoveAdapter(existing.Provider)
|
|
}
|
|
h.logger.Info("provider config deleted", zap.String("provider", existing.Provider))
|
|
w.WriteHeader(http.StatusNoContent)
|
|
}
|
|
|
|
func (h *Handler) testProviderConfig(w http.ResponseWriter, r *http.Request) {
|
|
tenantID, ok := tenantFromCtx(w, r)
|
|
if !ok {
|
|
return
|
|
}
|
|
if h.db == nil {
|
|
apierror.WriteError(w, &apierror.APIError{Type: "not_implemented", Message: "database not configured", HTTPStatus: http.StatusNotImplemented})
|
|
return
|
|
}
|
|
id := chi.URLParam(r, "id")
|
|
|
|
ps := newProviderStore(h.db, h.encryptor, h.logger)
|
|
existing, apiKeyEnc, err := ps.get(r.Context(), id, tenantID)
|
|
if err != nil {
|
|
apierror.WriteError(w, apierror.NewUpstreamError(err.Error()))
|
|
return
|
|
}
|
|
if existing == nil {
|
|
apierror.WriteError(w, &apierror.APIError{Type: "not_found_error", Message: "provider not found", HTTPStatus: http.StatusNotFound})
|
|
return
|
|
}
|
|
|
|
apiKey := ""
|
|
if h.encryptor != nil && apiKeyEnc != "" {
|
|
if dec, decErr := h.encryptor.Decrypt(apiKeyEnc); decErr == nil {
|
|
apiKey = dec
|
|
}
|
|
} else {
|
|
apiKey = apiKeyEnc
|
|
}
|
|
|
|
adapter := buildAdapter(existing, apiKey)
|
|
if adapter == nil {
|
|
apierror.WriteError(w, apierror.NewBadRequestError("cannot build adapter for provider "+existing.Provider))
|
|
return
|
|
}
|
|
|
|
if err := adapter.HealthCheck(r.Context()); err != nil {
|
|
writeJSON(w, http.StatusOK, map[string]interface{}{
|
|
"healthy": false,
|
|
"error": err.Error(),
|
|
})
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]interface{}{"healthy": true})
|
|
}
|
|
|
|
// ─── Adapter helpers ──────────────────────────────────────────────────────────
|
|
|
|
// reloadAdapter rebuilds the provider adapter from cfg and registers it on the router.
|
|
func (h *Handler) reloadAdapter(_ context.Context, cfg *ProviderConfig, apiKey string) {
|
|
if h.adapterRouter == nil {
|
|
return
|
|
}
|
|
adapter := buildAdapter(cfg, apiKey)
|
|
if adapter != nil {
|
|
h.adapterRouter.UpdateAdapter(cfg.Provider, adapter)
|
|
}
|
|
}
|
|
|
|
// BuildProviderAdapter constructs the correct provider.Adapter from a ProviderConfig.
|
|
// Exported so main.go can call it when loading provider configs from DB at startup.
|
|
func BuildProviderAdapter(cfg *ProviderConfig, apiKey string) provider.Adapter {
|
|
return buildAdapter(cfg, apiKey)
|
|
}
|
|
|
|
// buildAdapter constructs the correct provider.Adapter from a ProviderConfig.
|
|
func buildAdapter(cfg *ProviderConfig, apiKey string) provider.Adapter {
|
|
switch cfg.Provider {
|
|
case "openai":
|
|
return openai.New(openai.Config{
|
|
APIKey: apiKey,
|
|
BaseURL: orDefault(cfg.BaseURL, "https://api.openai.com/v1"),
|
|
TimeoutSeconds: orDefaultInt(cfg.TimeoutSec, 30),
|
|
MaxConns: orDefaultInt(cfg.MaxConns, 100),
|
|
})
|
|
case "anthropic":
|
|
return anthropic.New(anthropic.Config{
|
|
APIKey: apiKey,
|
|
BaseURL: orDefault(cfg.BaseURL, "https://api.anthropic.com/v1"),
|
|
Version: "2023-06-01",
|
|
TimeoutSeconds: orDefaultInt(cfg.TimeoutSec, 30),
|
|
MaxConns: orDefaultInt(cfg.MaxConns, 100),
|
|
})
|
|
case "azure":
|
|
return azure.New(azure.Config{
|
|
APIKey: apiKey,
|
|
ResourceName: cfg.ResourceName,
|
|
DeploymentID: cfg.DeploymentID,
|
|
APIVersion: orDefault(cfg.APIVersion, "2024-02-01"),
|
|
TimeoutSeconds: orDefaultInt(cfg.TimeoutSec, 30),
|
|
MaxConns: orDefaultInt(cfg.MaxConns, 100),
|
|
})
|
|
case "mistral":
|
|
return mistral.New(mistral.Config{
|
|
APIKey: apiKey,
|
|
BaseURL: orDefault(cfg.BaseURL, "https://api.mistral.ai/v1"),
|
|
TimeoutSeconds: orDefaultInt(cfg.TimeoutSec, 30),
|
|
MaxConns: orDefaultInt(cfg.MaxConns, 100),
|
|
})
|
|
case "ollama":
|
|
return ollama.New(ollama.Config{
|
|
BaseURL: orDefault(cfg.BaseURL, "http://localhost:11434/v1"),
|
|
TimeoutSeconds: orDefaultInt(cfg.TimeoutSec, 120),
|
|
MaxConns: orDefaultInt(cfg.MaxConns, 10),
|
|
})
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ─── Helpers ──────────────────────────────────────────────────────────────────
|
|
|
|
// validProvider reports whether p names one of the supported provider
// backends. The comparison is exact and case-sensitive.
func validProvider(p string) bool {
	for _, known := range [...]string{"openai", "anthropic", "azure", "mistral", "ollama"} {
		if p == known {
			return true
		}
	}
	return false
}
|
|
|
|
// defaultTimeout returns the request timeout, in seconds, used when a config
// does not specify one: 120 for ollama and 30 for every other provider. These
// match buildAdapter's fallbacks.
func defaultTimeout(name string) int {
	switch name {
	case "ollama":
		return 120
	default:
		return 30
	}
}
|
|
|
|
func displayName(req upsertProviderRequest) string {
|
|
if req.DisplayName != "" {
|
|
return req.DisplayName
|
|
}
|
|
return strings.Title(req.Provider) //nolint:staticcheck
|
|
}
|
|
|
|
func encryptKey(key string, enc *crypto.Encryptor) string {
|
|
if key == "" {
|
|
return ""
|
|
}
|
|
if enc == nil {
|
|
return key // dev mode: store plaintext
|
|
}
|
|
encrypted, err := enc.Encrypt(key)
|
|
if err != nil {
|
|
return key // fallback to plaintext on error
|
|
}
|
|
return encrypted
|
|
}
|
|
|
|
func maskAPIKey(apiKeyEnc string, enc *crypto.Encryptor) string {
|
|
if apiKeyEnc == "" {
|
|
return ""
|
|
}
|
|
// Decrypt to get the real key, then mask it
|
|
plaintext := apiKeyEnc
|
|
if enc != nil {
|
|
if dec, err := enc.Decrypt(apiKeyEnc); err == nil {
|
|
plaintext = dec
|
|
}
|
|
}
|
|
if len(plaintext) <= 8 {
|
|
return strings.Repeat("•", len(plaintext))
|
|
}
|
|
return plaintext[:8] + "••••••••"
|
|
}
|
|
|
|
// orDefault returns s, or def when s is the empty string.
func orDefault(s, def string) string {
	if s != "" {
		return s
	}
	return def
}
|
|
|
|
// orDefaultInt returns v, or def when v is exactly zero. Negative values are
// returned unchanged.
func orDefaultInt(v, def int) int {
	if v != 0 {
		return v
	}
	return def
}
|
|
|
|
// pqStringArray is a []string that scans PostgreSQL TEXT[] columns.
//
// NOTE: Value returns interface{} rather than driver.Value (a distinct named
// type in database/sql/driver), so this type does not satisfy driver.Valuer.
// Inserts in this package use modelPatternsLiteral with an explicit ::text[]
// cast instead.
type pqStringArray []string

// Scan implements sql.Scanner for the textual array form PostgreSQL emits, for
// example {a,"b,c","d\"e"}, given as a string or []byte. Quoted elements may
// contain commas and backslash escapes. Empty elements and unquoted NULLs are
// dropped. A nil source, or a value of any other type, yields an empty
// (non-nil) slice.
func (a *pqStringArray) Scan(src interface{}) error {
	var s string
	switch v := src.(type) {
	case string:
		s = v
	case []byte:
		s = string(v)
	default:
		*a = []string{}
		return nil
	}
	*a = splitPGTextArray(s)
	return nil
}

// splitPGTextArray parses a one-dimensional PostgreSQL array literal into its
// non-empty, non-NULL elements, honouring double quotes and backslash escapes.
func splitPGTextArray(s string) []string {
	s = strings.TrimPrefix(s, "{")
	s = strings.TrimSuffix(s, "}")

	out := []string{}
	var (
		cur      strings.Builder
		inQuotes bool // currently inside a "..." element
		quoted   bool // current element used quotes, so a literal "NULL" is data
		escaped  bool // previous byte was a backslash
	)
	emit := func() {
		v := cur.String()
		cur.Reset()
		if v != "" && (quoted || v != "NULL") {
			out = append(out, v)
		}
		quoted = false
	}
	// Walk bytes rather than runes. Every delimiter is ASCII, so multi-byte
	// UTF-8 sequences are copied through untouched.
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case escaped:
			cur.WriteByte(c)
			escaped = false
		case c == '\\':
			escaped = true
		case c == '"':
			inQuotes = !inQuotes
			quoted = true
		case c == ',' && !inQuotes:
			emit()
		default:
			cur.WriteByte(c)
		}
	}
	emit()
	return out
}

// Value renders the slice as a PostgreSQL array literal. Every element is
// double-quoted, with backslashes and quotes escaped, e.g. {"a","b\"c"}.
// An empty slice renders as {}.
func (a pqStringArray) Value() (interface{}, error) {
	if len(a) == 0 {
		return "{}", nil
	}
	parts := make([]string, len(a))
	for i, v := range a {
		// Escape backslashes first so the quote escapes are not doubled.
		v = strings.ReplaceAll(v, `\`, `\\`)
		v = strings.ReplaceAll(v, `"`, `\"`)
		parts[i] = `"` + v + `"`
	}
	return "{" + strings.Join(parts, ",") + "}", nil
}
|
|
|
|
// modelPatternsLiteral renders patterns as a PostgreSQL array literal, for
// example {"a","b"}, to be bound to a $N::text[] placeholder. Passing a string
// sidesteps driver.Valuer handling in the pgx stdlib adapter.
//
// Each element is double-quoted, and backslashes and double quotes are
// backslash-escaped so PostgreSQL reads the original text back. Backslashes used
// to go unescaped and were reinterpreted by the server. Nil or empty input
// renders as {}.
func modelPatternsLiteral(patterns []string) string {
	if len(patterns) == 0 {
		return "{}"
	}
	parts := make([]string, len(patterns))
	for i, v := range patterns {
		// Escape backslashes first so the quote escapes are not doubled.
		v = strings.ReplaceAll(v, `\`, `\\`)
		v = strings.ReplaceAll(v, `"`, `\"`)
		parts[i] = `"` + v + `"`
	}
	return "{" + strings.Join(parts, ",") + "}"
}
|
|
|
|
// claimsFromRequest is a local helper (avoids import of middleware in this file).
|
|
func claimsFromRequest(r *http.Request) *middleware.UserClaims {
|
|
claims, _ := middleware.ClaimsFromContext(r.Context())
|
|
return claims
|
|
}
|