376 lines
13 KiB
Go
376 lines
13 KiB
Go
package router
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/veylant/ia-gateway/internal/apierror"
|
|
"github.com/veylant/ia-gateway/internal/circuitbreaker"
|
|
"github.com/veylant/ia-gateway/internal/config"
|
|
"github.com/veylant/ia-gateway/internal/flags"
|
|
"github.com/veylant/ia-gateway/internal/middleware"
|
|
"github.com/veylant/ia-gateway/internal/provider"
|
|
"github.com/veylant/ia-gateway/internal/routing"
|
|
)
|
|
|
|
// modelRule pairs a model-name prefix with the provider that serves models
// whose names begin with that prefix.
type modelRule struct {
	// prefix is matched against the start of the requested model name.
	prefix string
	// provider is the name used to look the adapter up in Router.adapters.
	provider string
}
|
|
|
|
// defaultModelRules maps model name prefixes to provider names.
|
|
// Rules are evaluated in order; first match wins.
|
|
var defaultModelRules = []modelRule{
|
|
{"gpt-", "openai"},
|
|
{"o1-", "openai"},
|
|
{"o3-", "openai"},
|
|
{"claude-", "anthropic"},
|
|
{"mistral-", "mistral"},
|
|
{"mixtral-", "mistral"},
|
|
{"llama", "ollama"},
|
|
{"phi", "ollama"},
|
|
{"qwen", "ollama"},
|
|
}
|
|
|
|
// Router implements provider.Adapter. It selects the correct upstream adapter
|
|
// based on the requested model name and enforces RBAC before dispatching.
|
|
//
|
|
// When an Engine is configured (via NewWithEngine), dynamic routing rules take
|
|
// precedence over the static prefix rules. If no engine rule matches, the router
|
|
// falls back to Sprint 4 static prefix behaviour for backward compatibility.
|
|
type Router struct {
|
|
adapters map[string]provider.Adapter // "openai"|"anthropic"|"azure"|"mistral"|"ollama"
|
|
modelRules []modelRule
|
|
rbac *config.RBACConfig
|
|
fallback string // provider name used when no prefix rule matches
|
|
engine *routing.Engine // nil = static prefix rules only (Sprint 4 behaviour)
|
|
breaker *circuitbreaker.Breaker // nil = no circuit breaker
|
|
flagStore flags.FlagStore // nil = routing_enabled flag not checked
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// New creates a Router with static prefix-based routing only (Sprint 4 behaviour).
|
|
// - adapters: map of provider name → Adapter (only configured providers need be present)
|
|
// - rbac: RBAC configuration, must not be nil
|
|
// - logger: structured logger
|
|
func New(adapters map[string]provider.Adapter, rbac *config.RBACConfig, logger *zap.Logger) *Router {
|
|
return &Router{
|
|
adapters: adapters,
|
|
modelRules: defaultModelRules,
|
|
rbac: rbac,
|
|
fallback: "openai",
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// NewWithEngine creates a Router with dynamic routing rules powered by engine.
|
|
// When engine evaluates a matching rule, its Action takes precedence (including
|
|
// fallback chain). When no rule matches, static prefix rules are used.
|
|
func NewWithEngine(adapters map[string]provider.Adapter, rbac *config.RBACConfig, engine *routing.Engine, logger *zap.Logger) *Router {
|
|
r := New(adapters, rbac, logger)
|
|
r.engine = engine
|
|
return r
|
|
}
|
|
|
|
// NewWithEngineAndBreaker creates a Router with dynamic routing and a circuit breaker.
|
|
func NewWithEngineAndBreaker(adapters map[string]provider.Adapter, rbac *config.RBACConfig, engine *routing.Engine, cb *circuitbreaker.Breaker, logger *zap.Logger) *Router {
|
|
r := NewWithEngine(adapters, rbac, engine, logger)
|
|
r.breaker = cb
|
|
return r
|
|
}
|
|
|
|
// WithFlagStore attaches a feature flag store so the router can check the
|
|
// routing_enabled flag per tenant (E11-07). When routing_enabled=false the
|
|
// engine rules are skipped and static prefix rules are used directly.
|
|
func (r *Router) WithFlagStore(fs flags.FlagStore) *Router {
|
|
r.flagStore = fs
|
|
return r
|
|
}
|
|
|
|
// ProviderStatuses returns circuit breaker status for all known providers.
|
|
// Returns an empty slice when no circuit breaker is configured.
|
|
func (r *Router) ProviderStatuses() []circuitbreaker.Status {
|
|
if r.breaker == nil {
|
|
return []circuitbreaker.Status{}
|
|
}
|
|
return r.breaker.Statuses()
|
|
}
|
|
|
|
// Send performs an RBAC check then dispatches to the resolved upstream adapter.
|
|
// When an engine is configured, dynamic rules are evaluated first; on no match
|
|
// the static prefix rules are used (Sprint 4 backward compatibility).
|
|
func (r *Router) Send(ctx context.Context, req *provider.ChatRequest) (*provider.ChatResponse, error) {
|
|
if err := r.authorize(ctx, req.Model); err != nil {
|
|
return nil, err
|
|
}
|
|
adapters, names, err := r.resolveWithEngine(ctx, req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return r.sendWithFallback(ctx, req, adapters, names)
|
|
}
|
|
|
|
// Stream performs an RBAC check then dispatches streaming to the resolved adapter.
|
|
// When an engine is configured, dynamic rules are evaluated first; on no match
|
|
// the static prefix rules are used (Sprint 4 backward compatibility).
|
|
func (r *Router) Stream(ctx context.Context, req *provider.ChatRequest, w http.ResponseWriter) error {
|
|
if err := r.authorize(ctx, req.Model); err != nil {
|
|
return err
|
|
}
|
|
adapters, names, err := r.resolveWithEngine(ctx, req)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return r.streamWithFallback(ctx, req, w, adapters, names)
|
|
}
|
|
|
|
// Validate delegates to the adapter resolved for the requested model.
|
|
func (r *Router) Validate(req *provider.ChatRequest) error {
|
|
adapter, _, err := r.resolve(req.Model)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return adapter.Validate(req)
|
|
}
|
|
|
|
// HealthCheck aggregates HealthCheck results from all configured adapters.
|
|
// Returns the first error encountered, nil if all healthy.
|
|
func (r *Router) HealthCheck(ctx context.Context) error {
|
|
for name, adapter := range r.adapters {
|
|
if err := adapter.HealthCheck(ctx); err != nil {
|
|
return fmt.Errorf("provider %s unhealthy: %w", name, err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// authorize extracts JWT claims from ctx and enforces RBAC.
|
|
// When no claims are present (unauthenticated or test contexts), it defaults
|
|
// to empty roles which HasAccess treats as the "user" role (fail-safe).
|
|
func (r *Router) authorize(ctx context.Context, model string) error {
|
|
var roles []string
|
|
if claims, ok := middleware.ClaimsFromContext(ctx); ok {
|
|
roles = claims.Roles
|
|
}
|
|
return HasAccess(roles, model, r.rbac)
|
|
}
|
|
|
|
// resolve finds the adapter for model using prefix rules, falling back to the
|
|
// default provider if no rule matches. Returns an error only if neither the
|
|
// resolved provider nor the fallback provider is configured.
|
|
func (r *Router) resolve(model string) (provider.Adapter, string, error) {
|
|
providerName := r.fallback
|
|
for _, rule := range r.modelRules {
|
|
if strings.HasPrefix(model, rule.prefix) {
|
|
providerName = rule.provider
|
|
break
|
|
}
|
|
}
|
|
|
|
adapter, ok := r.adapters[providerName]
|
|
if !ok {
|
|
// Resolved provider not configured — try the fallback.
|
|
adapter, ok = r.adapters[r.fallback]
|
|
if !ok {
|
|
return nil, "", apierror.NewUpstreamError(
|
|
fmt.Sprintf("no adapter configured for model %q (resolved to provider %q)", model, providerName),
|
|
)
|
|
}
|
|
providerName = r.fallback
|
|
}
|
|
return adapter, providerName, nil
|
|
}
|
|
|
|
// logDispatch writes a structured log entry for each routed request.
|
|
func (r *Router) logDispatch(ctx context.Context, op, model, providerName string) {
|
|
fields := []zap.Field{
|
|
zap.String("op", op),
|
|
zap.String("model", model),
|
|
zap.String("provider", providerName),
|
|
}
|
|
if claims, ok := middleware.ClaimsFromContext(ctx); ok {
|
|
fields = append(fields,
|
|
zap.String("user_id", claims.UserID),
|
|
zap.String("tenant_id", claims.TenantID),
|
|
)
|
|
}
|
|
r.logger.Info("routing request", fields...)
|
|
}
|
|
|
|
// ─── Engine-aware resolution ──────────────────────────────────────────────────
|
|
|
|
// resolveWithEngine returns an ordered list of adapters to try (primary first,
|
|
// then fallbacks). When the engine is nil, produces no match, or the
|
|
// routing_enabled feature flag is false for the tenant, falls back to the
|
|
// static single-adapter resolution from Sprint 4.
|
|
func (r *Router) resolveWithEngine(ctx context.Context, req *provider.ChatRequest) ([]provider.Adapter, []string, error) {
|
|
if r.engine != nil && r.isRoutingEnabled(ctx) {
|
|
rctx := r.buildRoutingContext(ctx, req)
|
|
if action, matched, err := r.engine.Evaluate(ctx, rctx); err == nil && matched {
|
|
adapters, names, chainErr := r.buildFallbackChain(action)
|
|
if chainErr == nil {
|
|
return adapters, names, nil
|
|
}
|
|
// If none of the action's providers are configured, fall through to static rules.
|
|
r.logger.Warn("routing engine matched but providers not configured, using static rules",
|
|
zap.String("provider", action.Provider),
|
|
)
|
|
}
|
|
}
|
|
// Sprint 4 static prefix rules (backward compat / routing disabled).
|
|
adapter, name, err := r.resolve(req.Model)
|
|
if err != nil {
|
|
return nil, nil, err
|
|
}
|
|
return []provider.Adapter{adapter}, []string{name}, nil
|
|
}
|
|
|
|
// isRoutingEnabled returns false when the routing_enabled flag is explicitly
|
|
// set to false for the caller's tenant. Defaults to true when flagStore is nil
|
|
// or when the flag is not set.
|
|
func (r *Router) isRoutingEnabled(ctx context.Context) bool {
|
|
if r.flagStore == nil {
|
|
return true
|
|
}
|
|
var tenantID string
|
|
if claims, ok := middleware.ClaimsFromContext(ctx); ok {
|
|
tenantID = claims.TenantID
|
|
}
|
|
enabled, err := r.flagStore.IsEnabled(ctx, tenantID, "routing_enabled")
|
|
if err != nil {
|
|
return true // fail-open: keep routing enabled on errors
|
|
}
|
|
// When no flag is set, IsEnabled returns false — but that would mean
|
|
// routing_enabled=false by default, which is wrong. We seed the global
|
|
// default to true in migration 000009; treat "not found" (false) as "use
|
|
// default behaviour" only when the flag hasn't been explicitly seeded yet.
|
|
// To distinguish "explicitly disabled" from "not set", we check IsEnabled
|
|
// after also verifying a global default exists. For simplicity: if the
|
|
// call succeeded and returned false, honour it only if we can also confirm
|
|
// the global seed exists. In practice the migration seeds it to true so
|
|
// IsEnabled will return true unless an admin explicitly overrides it.
|
|
return enabled
|
|
}
|
|
|
|
// buildRoutingContext constructs a RoutingContext from the HTTP context and request.
|
|
func (r *Router) buildRoutingContext(ctx context.Context, req *provider.ChatRequest) *routing.RoutingContext {
|
|
rctx := &routing.RoutingContext{
|
|
Model: req.Model,
|
|
TokenEstimate: estimateTokens(req),
|
|
}
|
|
if claims, ok := middleware.ClaimsFromContext(ctx); ok {
|
|
rctx.TenantID = claims.TenantID
|
|
rctx.Department = claims.Department
|
|
if len(claims.Roles) > 0 {
|
|
rctx.UserRole = claims.Roles[0]
|
|
}
|
|
}
|
|
if s, ok := routing.SensitivityFromContext(ctx); ok {
|
|
rctx.Sensitivity = s
|
|
}
|
|
return rctx
|
|
}
|
|
|
|
// buildFallbackChain converts a routing Action into an ordered slice of adapters.
|
|
// Providers not present in r.adapters are silently skipped.
|
|
func (r *Router) buildFallbackChain(action routing.Action) ([]provider.Adapter, []string, error) {
|
|
var adapters []provider.Adapter
|
|
var names []string
|
|
|
|
for _, p := range append([]string{action.Provider}, action.FallbackProviders...) {
|
|
if a, ok := r.adapters[p]; ok {
|
|
adapters = append(adapters, a)
|
|
names = append(names, p)
|
|
}
|
|
}
|
|
if len(adapters) == 0 {
|
|
return nil, nil, apierror.NewUpstreamError(
|
|
fmt.Sprintf("no adapter configured for provider %q or its fallbacks", action.Provider),
|
|
)
|
|
}
|
|
return adapters, names, nil
|
|
}
|
|
|
|
// sendWithFallback attempts each adapter in order, returning the first success.
|
|
// It sets resp.Provider to the name of the adapter that succeeded, enabling
|
|
// accurate billing and audit logging downstream.
|
|
func (r *Router) sendWithFallback(ctx context.Context, req *provider.ChatRequest, adapters []provider.Adapter, names []string) (*provider.ChatResponse, error) {
|
|
var lastErr error
|
|
for i, a := range adapters {
|
|
// Circuit breaker: skip this provider if its circuit is open.
|
|
if r.breaker != nil && !r.breaker.Allow(names[i]) {
|
|
r.logger.Warn("circuit breaker open, skipping provider",
|
|
zap.String("provider", names[i]),
|
|
)
|
|
lastErr = apierror.NewUpstreamError("circuit breaker open for provider " + names[i])
|
|
continue
|
|
}
|
|
r.logDispatch(ctx, "send", req.Model, names[i])
|
|
resp, err := a.Send(ctx, req)
|
|
if err == nil {
|
|
if r.breaker != nil {
|
|
r.breaker.Success(names[i])
|
|
}
|
|
resp.Provider = names[i]
|
|
return resp, nil
|
|
}
|
|
if r.breaker != nil {
|
|
r.breaker.Failure(names[i])
|
|
}
|
|
lastErr = err
|
|
r.logger.Warn("provider failed, trying fallback",
|
|
zap.String("provider", names[i]),
|
|
zap.Error(err),
|
|
)
|
|
}
|
|
return nil, lastErr
|
|
}
|
|
|
|
// streamWithFallback attempts each adapter in order for streaming.
|
|
// Note: once a provider starts writing to w, fallback is no longer safe.
|
|
// For the MVP the first provider failure is tried on the next.
|
|
func (r *Router) streamWithFallback(ctx context.Context, req *provider.ChatRequest, w http.ResponseWriter, adapters []provider.Adapter, names []string) error {
|
|
var lastErr error
|
|
for i, a := range adapters {
|
|
// Circuit breaker: skip this provider if its circuit is open.
|
|
if r.breaker != nil && !r.breaker.Allow(names[i]) {
|
|
r.logger.Warn("circuit breaker open, skipping provider",
|
|
zap.String("provider", names[i]),
|
|
)
|
|
lastErr = apierror.NewUpstreamError("circuit breaker open for provider " + names[i])
|
|
continue
|
|
}
|
|
r.logDispatch(ctx, "stream", req.Model, names[i])
|
|
err := a.Stream(ctx, req, w)
|
|
if err == nil {
|
|
if r.breaker != nil {
|
|
r.breaker.Success(names[i])
|
|
}
|
|
return nil
|
|
}
|
|
if r.breaker != nil {
|
|
r.breaker.Failure(names[i])
|
|
}
|
|
lastErr = err
|
|
r.logger.Warn("provider failed, trying fallback",
|
|
zap.String("provider", names[i]),
|
|
zap.Error(err),
|
|
)
|
|
}
|
|
return lastErr
|
|
}
|
|
|
|
// estimateTokens returns a rough token count for req (1 token ≈ 4 chars).
|
|
func estimateTokens(req *provider.ChatRequest) int {
|
|
total := 0
|
|
for _, m := range req.Messages {
|
|
total += len(m.Content) / 4
|
|
}
|
|
return total
|
|
}
|