veylant/internal/router/router_test.go
2026-02-23 13:35:04 +01:00

445 lines
15 KiB
Go

package router_test
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"github.com/veylant/ia-gateway/internal/config"
"github.com/veylant/ia-gateway/internal/middleware"
"github.com/veylant/ia-gateway/internal/provider"
"github.com/veylant/ia-gateway/internal/router"
"github.com/veylant/ia-gateway/internal/routing"
)
// ─── Mock adapter ────────────────────────────────────────────────────────────
// mockAdapter is a hand-rolled test double for provider.Adapter. It remembers
// whether Send or Stream was invoked and answers with canned data. The flags
// are plain bools with no synchronization, so one mock should not be shared
// across concurrent calls.
type mockAdapter struct {
	// name identifies the mock inside the responses it produces.
	name string

	// sendCalled and streamCalled are flipped to true by Send and Stream.
	sendCalled, streamCalled bool

	// healthErr is returned verbatim by HealthCheck (nil means healthy).
	healthErr error
}
// Send marks the mock as used for a non-streaming call and returns a single
// assistant choice. The response ID and content embed the mock's name, and the
// request's model is echoed back. It never fails.
func (m *mockAdapter) Send(_ context.Context, req *provider.ChatRequest) (*provider.ChatResponse, error) {
	m.sendCalled = true
	reply := provider.Message{Role: "assistant", Content: "response from " + m.name}
	resp := &provider.ChatResponse{
		ID:      "mock-" + m.name,
		Model:   req.Model,
		Choices: []provider.Choice{{Message: reply}},
	}
	return resp, nil
}
// Stream marks the mock as used for a streaming call and writes only an
// event-stream style "[DONE]" terminator to w. It always returns nil.
func (m *mockAdapter) Stream(_ context.Context, _ *provider.ChatRequest, w http.ResponseWriter) error {
	const doneFrame = "data: [DONE]\n\n"
	m.streamCalled = true
	_, _ = w.Write([]byte(doneFrame)) // write failures are irrelevant to the mock
	return nil
}
// Validate accepts any request that names a model and rejects the rest with
// a "model required" error.
func (m *mockAdapter) Validate(req *provider.ChatRequest) error {
	if req.Model != "" {
		return nil
	}
	return errors.New("model required")
}
// HealthCheck reports whatever error the test stored in healthErr.
func (m *mockAdapter) HealthCheck(context.Context) error {
	err := m.healthErr
	return err
}
// ─── Helpers ─────────────────────────────────────────────────────────────────
// newAdapters builds one fresh mockAdapter per provider name ("openai",
// "anthropic", "mistral", "ollama"). Each mock is returned twice: once typed as
// provider.Adapter for the router, and once as *mockAdapter so tests can
// inspect its call flags. Both maps share the same underlying mocks.
func newAdapters() (map[string]provider.Adapter, map[string]*mockAdapter) {
	providerNames := []string{"openai", "anthropic", "mistral", "ollama"}
	spies := make(map[string]*mockAdapter, len(providerNames))
	byName := make(map[string]provider.Adapter, len(providerNames))
	for _, pn := range providerNames {
		spy := &mockAdapter{name: pn}
		spies[pn] = spy
		byName[pn] = spy
	}
	return byName, spies
}
// adminCtx returns a background context carrying claims for user "u1" in
// tenant "t1" with the "admin" role.
func adminCtx() context.Context {
	claims := &middleware.UserClaims{UserID: "u1", TenantID: "t1", Roles: []string{"admin"}}
	return middleware.WithClaims(context.Background(), claims)
}
// userCtx returns a background context carrying claims for user "u2" in
// tenant "t1" with the plain "user" role.
//
// Which models this user may call is decided by rbacConfig's
// UserAllowedModels, not by this helper. The variadic argument was never read,
// so it is now an explicit blank parameter. Existing call sites still compile,
// and readers are not misled into thinking they can grant models here.
func userCtx(_ ...string) context.Context {
	return middleware.WithClaims(context.Background(), &middleware.UserClaims{
		UserID:   "u2",
		TenantID: "t1",
		Roles:    []string{"user"},
	})
}
// auditorCtx returns a background context carrying claims for user "u3" in
// tenant "t1" with only the "auditor" role.
func auditorCtx() context.Context {
	base := context.Background()
	claims := &middleware.UserClaims{UserID: "u3", TenantID: "t1", Roles: []string{"auditor"}}
	return middleware.WithClaims(base, claims)
}
// rbacConfig returns the RBAC policy shared by these tests. Plain users may
// only call gpt-4o-mini, gpt-3.5-turbo and mistral-small, and auditors may not
// request completions at all.
func rbacConfig() *config.RBACConfig {
	cfg := new(config.RBACConfig)
	cfg.UserAllowedModels = []string{"gpt-4o-mini", "gpt-3.5-turbo", "mistral-small"}
	cfg.AuditorCanComplete = false
	return cfg
}
// ─── Routing ─────────────────────────────────────────────────────────────────
// TestRouter_Send_RoutesToOpenAI checks that an admin request for gpt-4o is
// dispatched to the openai adapter and not to anthropic.
func TestRouter_Send_RoutesToOpenAI(t *testing.T) {
	adapters, spies := newAdapters()
	rt := router.New(adapters, rbacConfig(), zap.NewNop())

	req := &provider.ChatRequest{
		Model:    "gpt-4o",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	}
	_, err := rt.Send(adminCtx(), req)

	require.NoError(t, err)
	assert.True(t, spies["openai"].sendCalled, "openai adapter should have been called")
	assert.False(t, spies["anthropic"].sendCalled)
}
// TestRouter_Send_RoutesToAnthropic checks that an admin request for a
// claude-* model is dispatched to the anthropic adapter and not to openai.
func TestRouter_Send_RoutesToAnthropic(t *testing.T) {
	adapters, spies := newAdapters()
	rt := router.New(adapters, rbacConfig(), zap.NewNop())

	req := &provider.ChatRequest{
		Model:    "claude-3-5-sonnet-20241022",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	}
	_, err := rt.Send(adminCtx(), req)

	require.NoError(t, err)
	assert.True(t, spies["anthropic"].sendCalled, "anthropic adapter should have been called")
	assert.False(t, spies["openai"].sendCalled)
}
// TestRouter_Send_RoutesToMistral checks that an admin request for a mistral-*
// model reaches the mistral adapter. It also checks that the request is not
// routed to openai, mirroring the negative checks in the OpenAI/Anthropic
// routing tests.
func TestRouter_Send_RoutesToMistral(t *testing.T) {
	adapters, mocks := newAdapters()
	r := router.New(adapters, rbacConfig(), zap.NewNop())
	_, err := r.Send(adminCtx(), &provider.ChatRequest{
		Model:    "mistral-medium",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	})
	require.NoError(t, err)
	assert.True(t, mocks["mistral"].sendCalled)
	// Guard against the request also (or instead) hitting the default provider.
	assert.False(t, mocks["openai"].sendCalled)
}
// TestRouter_Send_RoutesToOllama_Llama checks that an admin request for a
// llama model reaches the ollama adapter. It also checks that the request is
// not routed to openai, mirroring the negative checks in the OpenAI/Anthropic
// routing tests.
func TestRouter_Send_RoutesToOllama_Llama(t *testing.T) {
	adapters, mocks := newAdapters()
	r := router.New(adapters, rbacConfig(), zap.NewNop())
	_, err := r.Send(adminCtx(), &provider.ChatRequest{
		Model:    "llama3",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	})
	require.NoError(t, err)
	assert.True(t, mocks["ollama"].sendCalled)
	// Guard against the request also (or instead) hitting the default provider.
	assert.False(t, mocks["openai"].sendCalled)
}
// TestRouter_Send_UnknownModel_FallsBackToOpenAI checks that a model name the
// static rules do not recognise is sent to openai.
func TestRouter_Send_UnknownModel_FallsBackToOpenAI(t *testing.T) {
	adapters, spies := newAdapters()
	rt := router.New(adapters, rbacConfig(), zap.NewNop())

	req := &provider.ChatRequest{
		Model:    "some-unknown-model",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	}
	if _, err := rt.Send(adminCtx(), req); err != nil {
		require.NoError(t, err)
	}

	assert.True(t, spies["openai"].sendCalled, "unknown model should fall back to openai")
}
// ─── RBAC ────────────────────────────────────────────────────────────────────
// TestRouter_Send_AdminCanAccessAll checks that the admin role is not
// restricted by UserAllowedModels. One model per provider family must succeed.
func TestRouter_Send_AdminCanAccessAll(t *testing.T) {
	adapters, _ := newAdapters()
	rt := router.New(adapters, rbacConfig(), zap.NewNop())

	models := []string{"gpt-4o", "claude-3-opus", "mistral-medium", "llama3"}
	for i := range models {
		req := &provider.ChatRequest{
			Model:    models[i],
			Messages: []provider.Message{{Role: "user", Content: "hi"}},
		}
		_, err := rt.Send(adminCtx(), req)
		assert.NoError(t, err, "admin should access %q", models[i])
	}
}
// TestRouter_Send_UserAllowedModel_Routed checks that a plain user can call a
// model on the allow-list (gpt-4o-mini) and that it reaches openai.
func TestRouter_Send_UserAllowedModel_Routed(t *testing.T) {
	adapters, spies := newAdapters()
	rt := router.New(adapters, rbacConfig(), zap.NewNop())

	req := &provider.ChatRequest{
		Model:    "gpt-4o-mini",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	}
	_, err := rt.Send(userCtx(), req)

	require.NoError(t, err)
	assert.True(t, spies["openai"].sendCalled)
}
// TestRouter_Send_UserBlockedUnauthorizedModel checks three things for a plain
// user asking for a model outside rbacConfig's UserAllowedModels: the call
// fails, the error names the model, and no provider adapter is contacted.
func TestRouter_Send_UserBlockedUnauthorizedModel(t *testing.T) {
	adapters, mocks := newAdapters()
	r := router.New(adapters, rbacConfig(), zap.NewNop())
	_, err := r.Send(userCtx(), &provider.ChatRequest{
		Model:    "claude-3-opus",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "claude-3-opus")
	// RBAC must reject before any upstream call; an error returned after the
	// provider was already hit would still leak the request.
	for name, spy := range mocks {
		assert.False(t, spy.sendCalled, "%s adapter must not be called for a blocked model", name)
	}
}
// TestRouter_Send_AuditorBlocked checks that, with AuditorCanComplete=false,
// an auditor cannot complete even an allow-listed model. The error must
// mention "auditor" and no provider adapter may be contacted.
func TestRouter_Send_AuditorBlocked(t *testing.T) {
	adapters, mocks := newAdapters()
	r := router.New(adapters, rbacConfig(), zap.NewNop())
	_, err := r.Send(auditorCtx(), &provider.ChatRequest{
		Model:    "gpt-4o-mini",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	})
	require.Error(t, err)
	assert.Contains(t, err.Error(), "auditor")
	// The role check must happen before dispatch, not after.
	for name, spy := range mocks {
		assert.False(t, spy.sendCalled, "%s adapter must not be called for an auditor", name)
	}
}
// ─── Stream ──────────────────────────────────────────────────────────────────
// TestRouter_Stream_RoutesToAnthropic checks that a streaming request for a
// claude-* model is handled by the anthropic adapter's Stream and not by
// openai. This mirrors the negative check in the Send routing tests.
func TestRouter_Stream_RoutesToAnthropic(t *testing.T) {
	adapters, mocks := newAdapters()
	r := router.New(adapters, rbacConfig(), zap.NewNop())
	rec := httptest.NewRecorder()
	err := r.Stream(adminCtx(), &provider.ChatRequest{
		Model:    "claude-3-5-sonnet-20241022",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	}, rec)
	require.NoError(t, err)
	assert.True(t, mocks["anthropic"].streamCalled)
	assert.False(t, mocks["openai"].streamCalled)
}
// ─── HealthCheck ─────────────────────────────────────────────────────────────
// TestRouter_HealthCheck_AllHealthy checks that the router reports healthy
// when every adapter's HealthCheck returns nil.
func TestRouter_HealthCheck_AllHealthy(t *testing.T) {
	adapters, _ := newAdapters()
	rt := router.New(adapters, rbacConfig(), zap.NewNop())
	require.NoError(t, rt.HealthCheck(context.Background()))
}
// TestRouter_HealthCheck_OneUnhealthy checks that a single failing adapter
// makes the router's health check fail, and that the error names the culprit.
func TestRouter_HealthCheck_OneUnhealthy(t *testing.T) {
	adapters, spies := newAdapters()
	spies["anthropic"].healthErr = errors.New("connection refused")
	rt := router.New(adapters, rbacConfig(), zap.NewNop())

	healthErr := rt.HealthCheck(context.Background())

	require.Error(t, healthErr)
	assert.Contains(t, healthErr.Error(), "anthropic")
}
// ─── Validate ────────────────────────────────────────────────────────────────
// TestRouter_Validate_DelegatestoAdapter checks that a request with an empty
// model fails validation. The mocks reject such requests, so an error is
// expected once the router delegates to them.
func TestRouter_Validate_DelegatestoAdapter(t *testing.T) {
	adapters, _ := newAdapters()
	rt := router.New(adapters, rbacConfig(), zap.NewNop())
	emptyModel := &provider.ChatRequest{Model: ""}
	require.Error(t, rt.Validate(emptyModel))
}
// TestRouter_Validate_ValidRequest checks that a request with a model and a
// message passes validation.
func TestRouter_Validate_ValidRequest(t *testing.T) {
	adapters, _ := newAdapters()
	rt := router.New(adapters, rbacConfig(), zap.NewNop())

	valid := &provider.ChatRequest{
		Model:    "gpt-4o",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	}
	require.NoError(t, rt.Validate(valid))
}
// ─── Stream RBAC ─────────────────────────────────────────────────────────────
// TestRouter_Stream_AuditorBlocked checks that the auditor restriction also
// applies to streaming. The error must mention "auditor" and no adapter may
// start a stream.
func TestRouter_Stream_AuditorBlocked(t *testing.T) {
	adapters, mocks := newAdapters()
	r := router.New(adapters, rbacConfig(), zap.NewNop())
	rec := httptest.NewRecorder()
	err := r.Stream(auditorCtx(), &provider.ChatRequest{
		Model:    "gpt-4o-mini",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	}, rec)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "auditor")
	// The role check must happen before dispatch, not after.
	for name, spy := range mocks {
		assert.False(t, spy.streamCalled, "%s adapter must not stream for an auditor", name)
	}
}
// ─── No fallback configured ───────────────────────────────────────────────────
func TestRouter_Send_NoAdapters_ReturnsError(t *testing.T) {
r := router.New(map[string]provider.Adapter{}, rbacConfig(), zap.NewNop())
_, err := r.Send(adminCtx(), &provider.ChatRequest{
Model: "gpt-4o",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
require.Error(t, err)
}
// ─── Engine integration ───────────────────────────────────────────────────────
// failAdapter wraps a mockAdapter but overrides Send so that every call fails.
// Tests use it to exercise the router's fallback chain. Stream, Validate and
// HealthCheck come unchanged from the embedded mock.
type failAdapter struct {
	*mockAdapter
}

// Send ignores its arguments and always reports "provider unavailable".
// It does not set the embedded mock's sendCalled flag.
func (*failAdapter) Send(context.Context, *provider.ChatRequest) (*provider.ChatResponse, error) {
	unavailable := errors.New("provider unavailable")
	return nil, unavailable
}
// newEngineWithRule builds a routing.Engine backed by an in-memory store that
// holds exactly the given rule. The 30s duration is passed straight through to
// routing.New (presumably a rule cache/refresh interval; confirm in the
// routing package).
//
// It panics if the rule cannot be stored. The error used to be discarded,
// which meant a rule that silently failed to save would let "no match" tests
// pass for the wrong reason.
func newEngineWithRule(rule routing.RoutingRule) *routing.Engine {
	store := routing.NewMemStore()
	if _, err := store.Create(context.Background(), rule); err != nil {
		panic("newEngineWithRule: storing rule " + rule.Name + ": " + err.Error())
	}
	return routing.New(store, 30*time.Second, zap.NewNop())
}
// TestRouter_WithEngine_SensitivityRoutesToOllama verifies that when the engine
// matches a critical-sensitivity rule, the request goes to ollama not openai.
func TestRouter_WithEngine_SensitivityRoutesToOllama(t *testing.T) {
adapters, mocks := newAdapters()
engine := newEngineWithRule(routing.RoutingRule{
TenantID: "t1",
Name: "critical → ollama",
Priority: 10,
IsEnabled: true,
Conditions: []routing.Condition{
{Field: "request.sensitivity", Operator: "gte", Value: "high"},
},
Action: routing.Action{Provider: "ollama"},
})
r := router.NewWithEngine(adapters, rbacConfig(), engine, zap.NewNop())
ctx := routing.WithSensitivity(adminCtx(), routing.SensitivityCritical)
_, err := r.Send(ctx, &provider.ChatRequest{
Model: "gpt-4o",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
require.NoError(t, err)
assert.True(t, mocks["ollama"].sendCalled, "critical sensitivity should route to ollama")
assert.False(t, mocks["openai"].sendCalled)
}
// TestRouter_WithEngine_NoMatch_FallsBackToStaticRules verifies that when no
// engine rule matches, static prefix routing is used (backward compat).
func TestRouter_WithEngine_NoMatch_FallsBackToStaticRules(t *testing.T) {
adapters, mocks := newAdapters()
// Engine rule requires "finance" dept — request has no dept.
engine := newEngineWithRule(routing.RoutingRule{
TenantID: "t1",
Name: "finance only",
Priority: 10,
IsEnabled: true,
Conditions: []routing.Condition{
{Field: "user.department", Operator: "eq", Value: "finance"},
},
Action: routing.Action{Provider: "ollama"},
})
r := router.NewWithEngine(adapters, rbacConfig(), engine, zap.NewNop())
_, err := r.Send(adminCtx(), &provider.ChatRequest{
Model: "gpt-4o",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
require.NoError(t, err)
assert.True(t, mocks["openai"].sendCalled, "no engine match → static prefix rule → openai")
assert.False(t, mocks["ollama"].sendCalled)
}
// TestRouter_FallbackChain_PrimaryFails_TriesNext verifies that when the primary
// provider fails, the next provider in the fallback chain is tried.
func TestRouter_FallbackChain_PrimaryFails_TriesNext(t *testing.T) {
mocks := map[string]*mockAdapter{
"openai": {name: "openai"},
"anthropic": {name: "anthropic"},
"mistral": {name: "mistral"},
"ollama": {name: "ollama"},
}
adapters := make(map[string]provider.Adapter, len(mocks))
for k, v := range mocks {
adapters[k] = v
}
// Replace ollama with a failing adapter.
adapters["ollama"] = &failAdapter{mocks["ollama"]}
engine := newEngineWithRule(routing.RoutingRule{
TenantID: "t1",
Name: "ollama with openai fallback",
Priority: 10,
IsEnabled: true,
Conditions: []routing.Condition{},
Action: routing.Action{Provider: "ollama", FallbackProviders: []string{"openai"}},
})
r := router.NewWithEngine(adapters, rbacConfig(), engine, zap.NewNop())
_, err := r.Send(adminCtx(), &provider.ChatRequest{
Model: "gpt-4o",
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
require.NoError(t, err)
assert.True(t, mocks["openai"].sendCalled, "fallback to openai after ollama fails")
}
// TestRouter_NilEngine_BackwardCompat checks that a router built with New (no
// routing engine) still uses the static prefix rules from before Sprint 5.
// A claude-* model must land on anthropic.
func TestRouter_NilEngine_BackwardCompat(t *testing.T) {
	adapters, spies := newAdapters()
	rt := router.New(adapters, rbacConfig(), zap.NewNop()) // no engine

	req := &provider.ChatRequest{
		Model:    "claude-3-opus",
		Messages: []provider.Message{{Role: "user", Content: "hi"}},
	}
	_, err := rt.Send(adminCtx(), req)

	require.NoError(t, err)
	assert.True(t, spies["anthropic"].sendCalled, "claude- prefix → anthropic")
}
// TestRouter_WithEngine_DepartmentRouting verifies department-based routing.
func TestRouter_WithEngine_DepartmentRouting(t *testing.T) {
adapters, mocks := newAdapters()
engine := newEngineWithRule(routing.RoutingRule{
TenantID: "t1",
Name: "engineering → anthropic",
Priority: 10,
IsEnabled: true,
Conditions: []routing.Condition{
{Field: "user.department", Operator: "eq", Value: "engineering"},
},
Action: routing.Action{Provider: "anthropic"},
})
r := router.NewWithEngine(adapters, rbacConfig(), engine, zap.NewNop())
ctx := middleware.WithClaims(context.Background(), &middleware.UserClaims{
UserID: "u1",
TenantID: "t1",
Roles: []string{"admin"},
Department: "engineering",
})
_, err := r.Send(ctx, &provider.ChatRequest{
Model: "gpt-4o", // model prefix → openai, but engine should override
Messages: []provider.Message{{Role: "user", Content: "hi"}},
})
require.NoError(t, err)
assert.True(t, mocks["anthropic"].sendCalled, "engineering dept should route to anthropic")
assert.False(t, mocks["openai"].sendCalled)
}