package middleware_test

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"go.uber.org/zap"

	"github.com/veylant/ia-gateway/internal/middleware"
	"github.com/veylant/ia-gateway/internal/ratelimit"
)

// makeRateLimitLimiter returns an enabled limiter with both BurstSize and
// UserBurst set to burst, and both RequestsPerMin and UserRPM set to 600.
// NOTE(review): tests below assume the burst, not the per-minute refill, is
// what decides admission for back-to-back requests — confirm against
// ratelimit.New.
func makeRateLimitLimiter(burst int) *ratelimit.Limiter {
	cfg := ratelimit.RateLimitConfig{
		RequestsPerMin: 600,
		BurstSize:      burst,
		UserRPM:        600,
		UserBurst:      burst,
		IsEnabled:      true,
	}
	return ratelimit.New(cfg, zap.NewNop())
}

// rlOkHandler returns 200 for every request (named to avoid a redeclaration
// conflict with auth_test.go).
var rlOkHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
	w.WriteHeader(http.StatusOK)
})

// rlRequestWithClaims builds a body-less test request whose context carries
// claims via middleware.WithClaims, standing in for whatever upstream
// middleware normally injects them. The rl prefix mirrors rlOkHandler to
// avoid name clashes with other test files in this package.
func rlRequestWithClaims(method, target string, claims *middleware.UserClaims) *http.Request {
	req := httptest.NewRequest(method, target, nil)
	return req.WithContext(middleware.WithClaims(req.Context(), claims))
}

func TestRateLimitMiddleware_AllowsWithinLimit(t *testing.T) {
	limiter := makeRateLimitLimiter(5)
	mw := middleware.RateLimit(limiter)(rlOkHandler)
	claims := &middleware.UserClaims{UserID: "u1", TenantID: "t1"}

	// Exactly burst-many requests from one user must all be admitted.
	for i := 0; i < 5; i++ {
		rec := httptest.NewRecorder()
		mw.ServeHTTP(rec, rlRequestWithClaims(http.MethodPost, "/v1/chat/completions", claims))
		assert.Equal(t, http.StatusOK, rec.Code, "request %d should be allowed", i+1)
	}
}

func TestRateLimitMiddleware_Returns429WhenExceeded(t *testing.T) {
	limiter := makeRateLimitLimiter(1) // burst = 1
	mw := middleware.RateLimit(limiter)(rlOkHandler)
	claims := &middleware.UserClaims{UserID: "u1", TenantID: "t1"}

	// First request: allowed.
	rec1 := httptest.NewRecorder()
	mw.ServeHTTP(rec1, rlRequestWithClaims(http.MethodPost, "/", claims))
	assert.Equal(t, http.StatusOK, rec1.Code)

	// Second request, issued immediately for the same claims: exceeded.
	// NOTE(review): if the limiter refills at RequestsPerMin/60 tokens per
	// second (10/s here), this relies on both calls landing within ~100ms;
	// revisit the configured rate if this ever flakes under -race.
	rec2 := httptest.NewRecorder()
	mw.ServeHTTP(rec2, rlRequestWithClaims(http.MethodPost, "/", claims))
	assert.Equal(t, http.StatusTooManyRequests, rec2.Code)
	assert.Equal(t, "1", rec2.Header().Get("Retry-After"),
		"429 response should include Retry-After (permitted by RFC 6585 sec. 4; this middleware is expected to set it)")
}

func TestRateLimitMiddleware_NoClaims_PassThrough(t *testing.T) {
	// Burst = 1 is very restrictive: if claim-less requests were limited at
	// all, the second iteration would be rejected. A single request could
	// not tell pass-through apart from limiting, so send several.
	limiter := makeRateLimitLimiter(1)
	mw := middleware.RateLimit(limiter)(rlOkHandler)

	for i := 0; i < 3; i++ {
		req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
		rec := httptest.NewRecorder()
		mw.ServeHTTP(rec, req)
		assert.Equal(t, http.StatusOK, rec.Code, "request %d without claims should pass through", i+1)
	}
}