package router import ( "context" "fmt" "net/http" "strings" "go.uber.org/zap" "github.com/veylant/ia-gateway/internal/apierror" "github.com/veylant/ia-gateway/internal/circuitbreaker" "github.com/veylant/ia-gateway/internal/config" "github.com/veylant/ia-gateway/internal/flags" "github.com/veylant/ia-gateway/internal/middleware" "github.com/veylant/ia-gateway/internal/provider" "github.com/veylant/ia-gateway/internal/routing" ) // modelRule maps a model name prefix to a provider name. type modelRule struct { prefix string provider string } // defaultModelRules maps model name prefixes to provider names. // Rules are evaluated in order; first match wins. // Ollama rules cover common self-hosted model families available via `ollama pull`. var defaultModelRules = []modelRule{ // Cloud providers (matched before generic local names) {"gpt-", "openai"}, {"o1-", "openai"}, {"o3-", "openai"}, {"claude-", "anthropic"}, {"mistral-", "mistral"}, // Mistral AI cloud (mistral-small, mistral-large…) {"mixtral-", "mistral"}, // Ollama / self-hosted models {"llama", "ollama"}, // llama3, llama3.2, llama2… {"phi", "ollama"}, // phi3, phi4, phi3.5… {"qwen", "ollama"}, // qwen2, qwen2.5, qwen-72b… {"gemma", "ollama"}, // gemma2, gemma3 (Google) {"deepseek", "ollama"}, // deepseek-r1, deepseek-coder… {"llava", "ollama"}, // llava, llava-phi3 (multimodal) {"tinyllama", "ollama"}, {"codellama", "ollama"}, {"yi-", "ollama"}, // yi-coder, yi-34b… {"vicuna", "ollama"}, {"starcoder", "ollama"}, // starcoder2… {"solar", "ollama"}, {"falcon", "ollama"}, {"orca", "ollama"}, // orca-mini… {"nous", "ollama"}, // nous-hermes… {"hermes", "ollama"}, {"wizard", "ollama"}, {"stable", "ollama"}, // stablelm… {"command-r", "ollama"}, // Cohere Command-R via Ollama {"neural", "ollama"}, } // Router implements provider.Adapter. It selects the correct upstream adapter // based on the requested model name and enforces RBAC before dispatching. // // When an Engine is configured (via NewWithEngine), dynamic routing rules take // precedence over the static prefix rules. If no engine rule matches, the router // falls back to Sprint 4 static prefix behaviour for backward compatibility. type Router struct { adapters map[string]provider.Adapter // "openai"|"anthropic"|"azure"|"mistral"|"ollama" modelRules []modelRule rbac *config.RBACConfig fallback string // provider name used when no prefix rule matches engine *routing.Engine // nil = static prefix rules only (Sprint 4 behaviour) breaker *circuitbreaker.Breaker // nil = no circuit breaker flagStore flags.FlagStore // nil = routing_enabled flag not checked logger *zap.Logger } // New creates a Router with static prefix-based routing only (Sprint 4 behaviour). // - adapters: map of provider name → Adapter (only configured providers need be present) // - rbac: RBAC configuration, must not be nil // - logger: structured logger func New(adapters map[string]provider.Adapter, rbac *config.RBACConfig, logger *zap.Logger) *Router { return &Router{ adapters: adapters, modelRules: defaultModelRules, rbac: rbac, fallback: "openai", logger: logger, } } // NewWithEngine creates a Router with dynamic routing rules powered by engine. // When engine evaluates a matching rule, its Action takes precedence (including // fallback chain). When no rule matches, static prefix rules are used. func NewWithEngine(adapters map[string]provider.Adapter, rbac *config.RBACConfig, engine *routing.Engine, logger *zap.Logger) *Router { r := New(adapters, rbac, logger) r.engine = engine return r } // NewWithEngineAndBreaker creates a Router with dynamic routing and a circuit breaker. func NewWithEngineAndBreaker(adapters map[string]provider.Adapter, rbac *config.RBACConfig, engine *routing.Engine, cb *circuitbreaker.Breaker, logger *zap.Logger) *Router { r := NewWithEngine(adapters, rbac, engine, logger) r.breaker = cb return r } // WithFlagStore attaches a feature flag store so the router can check the // routing_enabled flag per tenant (E11-07). When routing_enabled=false the // engine rules are skipped and static prefix rules are used directly. func (r *Router) WithFlagStore(fs flags.FlagStore) *Router { r.flagStore = fs return r } // UpdateAdapter registers or replaces the adapter for the given provider name. // This allows dynamic reloading from the database without a proxy restart. func (r *Router) UpdateAdapter(name string, adapter provider.Adapter) { r.adapters[name] = adapter r.logger.Info("provider adapter updated", zap.String("provider", name)) } // RemoveAdapter removes the adapter for the given provider name. func (r *Router) RemoveAdapter(name string) { delete(r.adapters, name) r.logger.Info("provider adapter removed", zap.String("provider", name)) } // AddModelRules prepends extra prefix rules (from DB provider configs). // Called after loading provider_configs so custom patterns take precedence. func (r *Router) AddModelRules(rules []modelRule) { r.modelRules = append(rules, r.modelRules...) } // ProviderStatuses returns circuit breaker status for all known providers. // Returns an empty slice when no circuit breaker is configured. func (r *Router) ProviderStatuses() []circuitbreaker.Status { if r.breaker == nil { return []circuitbreaker.Status{} } return r.breaker.Statuses() } // Send performs an RBAC check then dispatches to the resolved upstream adapter. // When an engine is configured, dynamic rules are evaluated first; on no match // the static prefix rules are used (Sprint 4 backward compatibility). func (r *Router) Send(ctx context.Context, req *provider.ChatRequest) (*provider.ChatResponse, error) { if err := r.authorize(ctx, req.Model); err != nil { return nil, err } adapters, names, err := r.resolveWithEngine(ctx, req) if err != nil { return nil, err } return r.sendWithFallback(ctx, req, adapters, names) } // Stream performs an RBAC check then dispatches streaming to the resolved adapter. // When an engine is configured, dynamic rules are evaluated first; on no match // the static prefix rules are used (Sprint 4 backward compatibility). func (r *Router) Stream(ctx context.Context, req *provider.ChatRequest, w http.ResponseWriter) error { if err := r.authorize(ctx, req.Model); err != nil { return err } adapters, names, err := r.resolveWithEngine(ctx, req) if err != nil { return err } return r.streamWithFallback(ctx, req, w, adapters, names) } // Validate delegates to the adapter resolved for the requested model. func (r *Router) Validate(req *provider.ChatRequest) error { adapter, _, err := r.resolve(req.Model) if err != nil { return err } return adapter.Validate(req) } // HealthCheck aggregates HealthCheck results from all configured adapters. // Returns the first error encountered, nil if all healthy. func (r *Router) HealthCheck(ctx context.Context) error { for name, adapter := range r.adapters { if err := adapter.HealthCheck(ctx); err != nil { return fmt.Errorf("provider %s unhealthy: %w", name, err) } } return nil } // authorize extracts JWT claims from ctx and enforces RBAC. // When no claims are present (unauthenticated or test contexts), it defaults // to empty roles which HasAccess treats as the "user" role (fail-safe). func (r *Router) authorize(ctx context.Context, model string) error { var roles []string if claims, ok := middleware.ClaimsFromContext(ctx); ok { roles = claims.Roles } return HasAccess(roles, model, r.rbac) } // resolve finds the adapter for model using prefix rules, falling back to the // default provider if no rule matches. Returns an error only if neither the // resolved provider nor the fallback provider is configured. func (r *Router) resolve(model string) (provider.Adapter, string, error) { providerName := r.fallback for _, rule := range r.modelRules { if strings.HasPrefix(model, rule.prefix) { providerName = rule.provider break } } adapter, ok := r.adapters[providerName] if !ok { // Resolved provider not configured — try the fallback. adapter, ok = r.adapters[r.fallback] if !ok { return nil, "", apierror.NewUpstreamError( fmt.Sprintf("no adapter configured for model %q (resolved to provider %q)", model, providerName), ) } providerName = r.fallback } return adapter, providerName, nil } // logDispatch writes a structured log entry for each routed request. func (r *Router) logDispatch(ctx context.Context, op, model, providerName string) { fields := []zap.Field{ zap.String("op", op), zap.String("model", model), zap.String("provider", providerName), } if claims, ok := middleware.ClaimsFromContext(ctx); ok { fields = append(fields, zap.String("user_id", claims.UserID), zap.String("tenant_id", claims.TenantID), ) } r.logger.Info("routing request", fields...) } // ─── Engine-aware resolution ────────────────────────────────────────────────── // resolveWithEngine returns an ordered list of adapters to try (primary first, // then fallbacks). When the engine is nil, produces no match, or the // routing_enabled feature flag is false for the tenant, falls back to the // static single-adapter resolution from Sprint 4. func (r *Router) resolveWithEngine(ctx context.Context, req *provider.ChatRequest) ([]provider.Adapter, []string, error) { if r.engine != nil && r.isRoutingEnabled(ctx) { rctx := r.buildRoutingContext(ctx, req) if action, matched, err := r.engine.Evaluate(ctx, rctx); err == nil && matched { adapters, names, chainErr := r.buildFallbackChain(action) if chainErr == nil { return adapters, names, nil } // If none of the action's providers are configured, fall through to static rules. r.logger.Warn("routing engine matched but providers not configured, using static rules", zap.String("provider", action.Provider), ) } } // Sprint 4 static prefix rules (backward compat / routing disabled). adapter, name, err := r.resolve(req.Model) if err != nil { return nil, nil, err } return []provider.Adapter{adapter}, []string{name}, nil } // isRoutingEnabled returns false when the routing_enabled flag is explicitly // set to false for the caller's tenant. Defaults to true when flagStore is nil // or when the flag is not set. func (r *Router) isRoutingEnabled(ctx context.Context) bool { if r.flagStore == nil { return true } var tenantID string if claims, ok := middleware.ClaimsFromContext(ctx); ok { tenantID = claims.TenantID } enabled, err := r.flagStore.IsEnabled(ctx, tenantID, "routing_enabled") if err != nil { return true // fail-open: keep routing enabled on errors } // When no flag is set, IsEnabled returns false — but that would mean // routing_enabled=false by default, which is wrong. We seed the global // default to true in migration 000009; treat "not found" (false) as "use // default behaviour" only when the flag hasn't been explicitly seeded yet. // To distinguish "explicitly disabled" from "not set", we check IsEnabled // after also verifying a global default exists. For simplicity: if the // call succeeded and returned false, honour it only if we can also confirm // the global seed exists. In practice the migration seeds it to true so // IsEnabled will return true unless an admin explicitly overrides it. return enabled } // buildRoutingContext constructs a RoutingContext from the HTTP context and request. func (r *Router) buildRoutingContext(ctx context.Context, req *provider.ChatRequest) *routing.RoutingContext { rctx := &routing.RoutingContext{ Model: req.Model, TokenEstimate: estimateTokens(req), } if claims, ok := middleware.ClaimsFromContext(ctx); ok { rctx.TenantID = claims.TenantID rctx.Department = claims.Department if len(claims.Roles) > 0 { rctx.UserRole = claims.Roles[0] } } if s, ok := routing.SensitivityFromContext(ctx); ok { rctx.Sensitivity = s } return rctx } // buildFallbackChain converts a routing Action into an ordered slice of adapters. // Providers not present in r.adapters are silently skipped. func (r *Router) buildFallbackChain(action routing.Action) ([]provider.Adapter, []string, error) { var adapters []provider.Adapter var names []string for _, p := range append([]string{action.Provider}, action.FallbackProviders...) { if a, ok := r.adapters[p]; ok { adapters = append(adapters, a) names = append(names, p) } } if len(adapters) == 0 { return nil, nil, apierror.NewUpstreamError( fmt.Sprintf("no adapter configured for provider %q or its fallbacks", action.Provider), ) } return adapters, names, nil } // sendWithFallback attempts each adapter in order, returning the first success. // It sets resp.Provider to the name of the adapter that succeeded, enabling // accurate billing and audit logging downstream. func (r *Router) sendWithFallback(ctx context.Context, req *provider.ChatRequest, adapters []provider.Adapter, names []string) (*provider.ChatResponse, error) { var lastErr error for i, a := range adapters { // Circuit breaker: skip this provider if its circuit is open. if r.breaker != nil && !r.breaker.Allow(names[i]) { r.logger.Warn("circuit breaker open, skipping provider", zap.String("provider", names[i]), ) lastErr = apierror.NewUpstreamError("circuit breaker open for provider " + names[i]) continue } r.logDispatch(ctx, "send", req.Model, names[i]) resp, err := a.Send(ctx, req) if err == nil { if r.breaker != nil { r.breaker.Success(names[i]) } resp.Provider = names[i] return resp, nil } if r.breaker != nil { r.breaker.Failure(names[i]) } lastErr = err r.logger.Warn("provider failed, trying fallback", zap.String("provider", names[i]), zap.Error(err), ) } return nil, lastErr } // streamWithFallback attempts each adapter in order for streaming. // Note: once a provider starts writing to w, fallback is no longer safe. // For the MVP the first provider failure is tried on the next. func (r *Router) streamWithFallback(ctx context.Context, req *provider.ChatRequest, w http.ResponseWriter, adapters []provider.Adapter, names []string) error { var lastErr error for i, a := range adapters { // Circuit breaker: skip this provider if its circuit is open. if r.breaker != nil && !r.breaker.Allow(names[i]) { r.logger.Warn("circuit breaker open, skipping provider", zap.String("provider", names[i]), ) lastErr = apierror.NewUpstreamError("circuit breaker open for provider " + names[i]) continue } r.logDispatch(ctx, "stream", req.Model, names[i]) err := a.Stream(ctx, req, w) if err == nil { if r.breaker != nil { r.breaker.Success(names[i]) } return nil } if r.breaker != nil { r.breaker.Failure(names[i]) } lastErr = err r.logger.Warn("provider failed, trying fallback", zap.String("provider", names[i]), zap.Error(err), ) } return lastErr } // estimateTokens returns a rough token count for req (1 token ≈ 4 chars). func estimateTokens(req *provider.ChatRequest) int { total := 0 for _, m := range req.Messages { total += len(m.Content) / 4 } return total }