445 lines
15 KiB
Go
445 lines
15 KiB
Go
package router_test
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/veylant/ia-gateway/internal/config"
|
|
"github.com/veylant/ia-gateway/internal/middleware"
|
|
"github.com/veylant/ia-gateway/internal/provider"
|
|
"github.com/veylant/ia-gateway/internal/router"
|
|
"github.com/veylant/ia-gateway/internal/routing"
|
|
)
|
|
|
|
// ─── Mock adapter ────────────────────────────────────────────────────────────
|
|
|
|
// mockAdapter is a provider.Adapter test double. It records whether Send or
// Stream was invoked and answers with canned data; HealthCheck reports
// healthErr (nil means healthy).
type mockAdapter struct {
	name string // used in the canned response ID and content

	// Call-tracking flags, set by Send and Stream respectively.
	sendCalled   bool
	streamCalled bool

	healthErr error // value returned verbatim by HealthCheck
}
|
|
|
|
func (m *mockAdapter) Send(_ context.Context, req *provider.ChatRequest) (*provider.ChatResponse, error) {
|
|
m.sendCalled = true
|
|
return &provider.ChatResponse{
|
|
ID: "mock-" + m.name,
|
|
Model: req.Model,
|
|
Choices: []provider.Choice{{
|
|
Message: provider.Message{Role: "assistant", Content: "response from " + m.name},
|
|
}},
|
|
}, nil
|
|
}
|
|
|
|
func (m *mockAdapter) Stream(_ context.Context, _ *provider.ChatRequest, w http.ResponseWriter) error {
|
|
m.streamCalled = true
|
|
_, _ = w.Write([]byte("data: [DONE]\n\n"))
|
|
return nil
|
|
}
|
|
|
|
func (m *mockAdapter) Validate(req *provider.ChatRequest) error {
|
|
if req.Model == "" {
|
|
return errors.New("model required")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *mockAdapter) HealthCheck(_ context.Context) error {
|
|
return m.healthErr
|
|
}
|
|
|
|
// ─── Helpers ─────────────────────────────────────────────────────────────────
|
|
|
|
func newAdapters() (map[string]provider.Adapter, map[string]*mockAdapter) {
|
|
mocks := map[string]*mockAdapter{
|
|
"openai": {name: "openai"},
|
|
"anthropic": {name: "anthropic"},
|
|
"mistral": {name: "mistral"},
|
|
"ollama": {name: "ollama"},
|
|
}
|
|
adapters := make(map[string]provider.Adapter, len(mocks))
|
|
for k, v := range mocks {
|
|
adapters[k] = v
|
|
}
|
|
return adapters, mocks
|
|
}
|
|
|
|
func adminCtx() context.Context {
|
|
return middleware.WithClaims(context.Background(), &middleware.UserClaims{
|
|
UserID: "u1", TenantID: "t1", Roles: []string{"admin"},
|
|
})
|
|
}
|
|
|
|
func userCtx(allowedModels ...string) context.Context {
|
|
return middleware.WithClaims(context.Background(), &middleware.UserClaims{
|
|
UserID: "u2", TenantID: "t1", Roles: []string{"user"},
|
|
})
|
|
}
|
|
|
|
func auditorCtx() context.Context {
|
|
return middleware.WithClaims(context.Background(), &middleware.UserClaims{
|
|
UserID: "u3", TenantID: "t1", Roles: []string{"auditor"},
|
|
})
|
|
}
|
|
|
|
func rbacConfig() *config.RBACConfig {
|
|
return &config.RBACConfig{
|
|
UserAllowedModels: []string{"gpt-4o-mini", "gpt-3.5-turbo", "mistral-small"},
|
|
AuditorCanComplete: false,
|
|
}
|
|
}
|
|
|
|
// ─── Routing ─────────────────────────────────────────────────────────────────
|
|
|
|
func TestRouter_Send_RoutesToOpenAI(t *testing.T) {
|
|
adapters, mocks := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
_, err := r.Send(adminCtx(), &provider.ChatRequest{
|
|
Model: "gpt-4o",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.NoError(t, err)
|
|
assert.True(t, mocks["openai"].sendCalled, "openai adapter should have been called")
|
|
assert.False(t, mocks["anthropic"].sendCalled)
|
|
}
|
|
|
|
func TestRouter_Send_RoutesToAnthropic(t *testing.T) {
|
|
adapters, mocks := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
_, err := r.Send(adminCtx(), &provider.ChatRequest{
|
|
Model: "claude-3-5-sonnet-20241022",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.NoError(t, err)
|
|
assert.True(t, mocks["anthropic"].sendCalled, "anthropic adapter should have been called")
|
|
assert.False(t, mocks["openai"].sendCalled)
|
|
}
|
|
|
|
func TestRouter_Send_RoutesToMistral(t *testing.T) {
|
|
adapters, mocks := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
_, err := r.Send(adminCtx(), &provider.ChatRequest{
|
|
Model: "mistral-medium",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.NoError(t, err)
|
|
assert.True(t, mocks["mistral"].sendCalled)
|
|
}
|
|
|
|
func TestRouter_Send_RoutesToOllama_Llama(t *testing.T) {
|
|
adapters, mocks := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
_, err := r.Send(adminCtx(), &provider.ChatRequest{
|
|
Model: "llama3",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.NoError(t, err)
|
|
assert.True(t, mocks["ollama"].sendCalled)
|
|
}
|
|
|
|
func TestRouter_Send_UnknownModel_FallsBackToOpenAI(t *testing.T) {
|
|
adapters, mocks := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
_, err := r.Send(adminCtx(), &provider.ChatRequest{
|
|
Model: "some-unknown-model",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.NoError(t, err)
|
|
assert.True(t, mocks["openai"].sendCalled, "unknown model should fall back to openai")
|
|
}
|
|
|
|
// ─── RBAC ────────────────────────────────────────────────────────────────────
|
|
|
|
func TestRouter_Send_AdminCanAccessAll(t *testing.T) {
|
|
adapters, _ := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
for _, model := range []string{"gpt-4o", "claude-3-opus", "mistral-medium", "llama3"} {
|
|
_, err := r.Send(adminCtx(), &provider.ChatRequest{
|
|
Model: model,
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
assert.NoError(t, err, "admin should access %q", model)
|
|
}
|
|
}
|
|
|
|
func TestRouter_Send_UserAllowedModel_Routed(t *testing.T) {
|
|
adapters, mocks := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
_, err := r.Send(userCtx(), &provider.ChatRequest{
|
|
Model: "gpt-4o-mini",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.NoError(t, err)
|
|
assert.True(t, mocks["openai"].sendCalled)
|
|
}
|
|
|
|
func TestRouter_Send_UserBlockedUnauthorizedModel(t *testing.T) {
|
|
adapters, _ := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
_, err := r.Send(userCtx(), &provider.ChatRequest{
|
|
Model: "claude-3-opus",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "claude-3-opus")
|
|
}
|
|
|
|
func TestRouter_Send_AuditorBlocked(t *testing.T) {
|
|
adapters, _ := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
_, err := r.Send(auditorCtx(), &provider.ChatRequest{
|
|
Model: "gpt-4o-mini",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "auditor")
|
|
}
|
|
|
|
// ─── Stream ──────────────────────────────────────────────────────────────────
|
|
|
|
func TestRouter_Stream_RoutesToAnthropic(t *testing.T) {
|
|
adapters, mocks := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
rec := httptest.NewRecorder()
|
|
err := r.Stream(adminCtx(), &provider.ChatRequest{
|
|
Model: "claude-3-5-sonnet-20241022",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
}, rec)
|
|
require.NoError(t, err)
|
|
assert.True(t, mocks["anthropic"].streamCalled)
|
|
}
|
|
|
|
// ─── HealthCheck ─────────────────────────────────────────────────────────────
|
|
|
|
func TestRouter_HealthCheck_AllHealthy(t *testing.T) {
|
|
adapters, _ := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
err := r.HealthCheck(context.Background())
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestRouter_HealthCheck_OneUnhealthy(t *testing.T) {
|
|
adapters, mocks := newAdapters()
|
|
mocks["anthropic"].healthErr = errors.New("connection refused")
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
err := r.HealthCheck(context.Background())
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "anthropic")
|
|
}
|
|
|
|
// ─── Validate ────────────────────────────────────────────────────────────────
|
|
|
|
func TestRouter_Validate_DelegatestoAdapter(t *testing.T) {
|
|
adapters, _ := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
// openai adapter requires model; should delegate and return its error
|
|
err := r.Validate(&provider.ChatRequest{Model: ""})
|
|
require.Error(t, err)
|
|
}
|
|
|
|
func TestRouter_Validate_ValidRequest(t *testing.T) {
|
|
adapters, _ := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
err := r.Validate(&provider.ChatRequest{
|
|
Model: "gpt-4o",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
// ─── Stream RBAC ─────────────────────────────────────────────────────────────
|
|
|
|
func TestRouter_Stream_AuditorBlocked(t *testing.T) {
|
|
adapters, _ := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop())
|
|
|
|
rec := httptest.NewRecorder()
|
|
err := r.Stream(auditorCtx(), &provider.ChatRequest{
|
|
Model: "gpt-4o-mini",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
}, rec)
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "auditor")
|
|
}
|
|
|
|
// ─── No fallback configured ───────────────────────────────────────────────────
|
|
|
|
func TestRouter_Send_NoAdapters_ReturnsError(t *testing.T) {
|
|
r := router.New(map[string]provider.Adapter{}, rbacConfig(), zap.NewNop())
|
|
|
|
_, err := r.Send(adminCtx(), &provider.ChatRequest{
|
|
Model: "gpt-4o",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.Error(t, err)
|
|
}
|
|
|
|
// ─── Engine integration ───────────────────────────────────────────────────────
|
|
|
|
// failAdapter always returns an error from Send.
|
|
type failAdapter struct{ *mockAdapter }
|
|
|
|
func (f *failAdapter) Send(_ context.Context, _ *provider.ChatRequest) (*provider.ChatResponse, error) {
|
|
return nil, errors.New("provider unavailable")
|
|
}
|
|
|
|
func newEngineWithRule(rule routing.RoutingRule) *routing.Engine {
|
|
store := routing.NewMemStore()
|
|
_, _ = store.Create(context.Background(), rule)
|
|
return routing.New(store, 30*time.Second, zap.NewNop())
|
|
}
|
|
|
|
// TestRouter_WithEngine_SensitivityRoutesToOllama verifies that when the engine
|
|
// matches a critical-sensitivity rule, the request goes to ollama not openai.
|
|
func TestRouter_WithEngine_SensitivityRoutesToOllama(t *testing.T) {
|
|
adapters, mocks := newAdapters()
|
|
engine := newEngineWithRule(routing.RoutingRule{
|
|
TenantID: "t1",
|
|
Name: "critical → ollama",
|
|
Priority: 10,
|
|
IsEnabled: true,
|
|
Conditions: []routing.Condition{
|
|
{Field: "request.sensitivity", Operator: "gte", Value: "high"},
|
|
},
|
|
Action: routing.Action{Provider: "ollama"},
|
|
})
|
|
r := router.NewWithEngine(adapters, rbacConfig(), engine, zap.NewNop())
|
|
|
|
ctx := routing.WithSensitivity(adminCtx(), routing.SensitivityCritical)
|
|
_, err := r.Send(ctx, &provider.ChatRequest{
|
|
Model: "gpt-4o",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.NoError(t, err)
|
|
assert.True(t, mocks["ollama"].sendCalled, "critical sensitivity should route to ollama")
|
|
assert.False(t, mocks["openai"].sendCalled)
|
|
}
|
|
|
|
// TestRouter_WithEngine_NoMatch_FallsBackToStaticRules verifies that when no
|
|
// engine rule matches, static prefix routing is used (backward compat).
|
|
func TestRouter_WithEngine_NoMatch_FallsBackToStaticRules(t *testing.T) {
|
|
adapters, mocks := newAdapters()
|
|
// Engine rule requires "finance" dept — request has no dept.
|
|
engine := newEngineWithRule(routing.RoutingRule{
|
|
TenantID: "t1",
|
|
Name: "finance only",
|
|
Priority: 10,
|
|
IsEnabled: true,
|
|
Conditions: []routing.Condition{
|
|
{Field: "user.department", Operator: "eq", Value: "finance"},
|
|
},
|
|
Action: routing.Action{Provider: "ollama"},
|
|
})
|
|
r := router.NewWithEngine(adapters, rbacConfig(), engine, zap.NewNop())
|
|
|
|
_, err := r.Send(adminCtx(), &provider.ChatRequest{
|
|
Model: "gpt-4o",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.NoError(t, err)
|
|
assert.True(t, mocks["openai"].sendCalled, "no engine match → static prefix rule → openai")
|
|
assert.False(t, mocks["ollama"].sendCalled)
|
|
}
|
|
|
|
// TestRouter_FallbackChain_PrimaryFails_TriesNext verifies that when the primary
|
|
// provider fails, the next provider in the fallback chain is tried.
|
|
func TestRouter_FallbackChain_PrimaryFails_TriesNext(t *testing.T) {
|
|
mocks := map[string]*mockAdapter{
|
|
"openai": {name: "openai"},
|
|
"anthropic": {name: "anthropic"},
|
|
"mistral": {name: "mistral"},
|
|
"ollama": {name: "ollama"},
|
|
}
|
|
adapters := make(map[string]provider.Adapter, len(mocks))
|
|
for k, v := range mocks {
|
|
adapters[k] = v
|
|
}
|
|
// Replace ollama with a failing adapter.
|
|
adapters["ollama"] = &failAdapter{mocks["ollama"]}
|
|
|
|
engine := newEngineWithRule(routing.RoutingRule{
|
|
TenantID: "t1",
|
|
Name: "ollama with openai fallback",
|
|
Priority: 10,
|
|
IsEnabled: true,
|
|
Conditions: []routing.Condition{},
|
|
Action: routing.Action{Provider: "ollama", FallbackProviders: []string{"openai"}},
|
|
})
|
|
r := router.NewWithEngine(adapters, rbacConfig(), engine, zap.NewNop())
|
|
|
|
_, err := r.Send(adminCtx(), &provider.ChatRequest{
|
|
Model: "gpt-4o",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.NoError(t, err)
|
|
assert.True(t, mocks["openai"].sendCalled, "fallback to openai after ollama fails")
|
|
}
|
|
|
|
// TestRouter_NilEngine_BackwardCompat verifies that a router without an engine
|
|
// behaves exactly as before Sprint 5 (static prefix rules).
|
|
func TestRouter_NilEngine_BackwardCompat(t *testing.T) {
|
|
adapters, mocks := newAdapters()
|
|
r := router.New(adapters, rbacConfig(), zap.NewNop()) // no engine
|
|
|
|
_, err := r.Send(adminCtx(), &provider.ChatRequest{
|
|
Model: "claude-3-opus",
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.NoError(t, err)
|
|
assert.True(t, mocks["anthropic"].sendCalled, "claude- prefix → anthropic")
|
|
}
|
|
|
|
// TestRouter_WithEngine_DepartmentRouting verifies department-based routing.
|
|
func TestRouter_WithEngine_DepartmentRouting(t *testing.T) {
|
|
adapters, mocks := newAdapters()
|
|
engine := newEngineWithRule(routing.RoutingRule{
|
|
TenantID: "t1",
|
|
Name: "engineering → anthropic",
|
|
Priority: 10,
|
|
IsEnabled: true,
|
|
Conditions: []routing.Condition{
|
|
{Field: "user.department", Operator: "eq", Value: "engineering"},
|
|
},
|
|
Action: routing.Action{Provider: "anthropic"},
|
|
})
|
|
r := router.NewWithEngine(adapters, rbacConfig(), engine, zap.NewNop())
|
|
|
|
ctx := middleware.WithClaims(context.Background(), &middleware.UserClaims{
|
|
UserID: "u1",
|
|
TenantID: "t1",
|
|
Roles: []string{"admin"},
|
|
Department: "engineering",
|
|
})
|
|
|
|
_, err := r.Send(ctx, &provider.ChatRequest{
|
|
Model: "gpt-4o", // model prefix → openai, but engine should override
|
|
Messages: []provider.Message{{Role: "user", Content: "hi"}},
|
|
})
|
|
require.NoError(t, err)
|
|
assert.True(t, mocks["anthropic"].sendCalled, "engineering dept should route to anthropic")
|
|
assert.False(t, mocks["openai"].sendCalled)
|
|
}
|