veylant/internal/middleware/ratelimit_test.go
2026-02-23 13:35:04 +01:00

79 lines
2.5 KiB
Go

package middleware_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
"github.com/veylant/ia-gateway/internal/middleware"
"github.com/veylant/ia-gateway/internal/ratelimit"
)
// makeRateLimitLimiter builds an enabled *ratelimit.Limiter for these tests.
// RequestsPerMin and UserRPM are both fixed at 600. The caller-supplied burst
// is used for both BurstSize and UserBurst. Log output is discarded through a
// no-op zap logger.
func makeRateLimitLimiter(burst int) *ratelimit.Limiter {
	const rpm = 600
	return ratelimit.New(ratelimit.RateLimitConfig{
		IsEnabled:      true,
		RequestsPerMin: rpm,
		UserRPM:        rpm,
		BurstSize:      burst,
		UserBurst:      burst,
	}, zap.NewNop())
}
// rlOkHandler is the terminal handler wrapped by the rate-limit middleware in
// these tests: it answers every request with 200 OK and an empty body. The
// "rl" prefix avoids a redeclaration clash with a helper in auth_test.go.
var rlOkHandler http.HandlerFunc = func(rw http.ResponseWriter, _ *http.Request) {
	rw.WriteHeader(http.StatusOK)
}
func TestRateLimitMiddleware_AllowsWithinLimit(t *testing.T) {
limiter := makeRateLimitLimiter(5)
mw := middleware.RateLimit(limiter)(rlOkHandler)
for i := 0; i < 5; i++ {
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
// Inject claims into context.
ctx := middleware.WithClaims(req.Context(), &middleware.UserClaims{
UserID: "u1", TenantID: "t1",
})
req = req.WithContext(ctx)
rec := httptest.NewRecorder()
mw.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code, "request %d should be allowed", i+1)
}
}
// TestRateLimitMiddleware_Returns429WhenExceeded checks the behavior once a
// user exceeds a burst of one. The user's first request must succeed. The
// immediately following request must be rejected with 429 Too Many Requests
// and carry a "Retry-After: 1" header.
func TestRateLimitMiddleware_Returns429WhenExceeded(t *testing.T) {
	limiter := makeRateLimitLimiter(1) // burst = 1
	mw := middleware.RateLimit(limiter)(rlOkHandler)
	claims := &middleware.UserClaims{UserID: "u1", TenantID: "t1"}

	// First request: uses up the burst of one and must be allowed.
	req1 := httptest.NewRequest(http.MethodPost, "/", nil)
	req1 = req1.WithContext(middleware.WithClaims(req1.Context(), claims))
	rec1 := httptest.NewRecorder()
	mw.ServeHTTP(rec1, req1)
	// The 429 assertions below only make sense if this request got through.
	// Stop here rather than report a cascade of misleading failures.
	if !assert.Equal(t, http.StatusOK, rec1.Code, "first request should be allowed") {
		t.FailNow()
	}

	// Second request, sent immediately: expected to exceed the limit.
	// NOTE(review): this assumes no capacity is replenished between the two
	// calls. At the helper's 600 RPM that is presumably ~100ms per request;
	// confirm the limiter's refill semantics if this flakes on slow CI.
	req2 := httptest.NewRequest(http.MethodPost, "/", nil)
	req2 = req2.WithContext(middleware.WithClaims(req2.Context(), claims))
	rec2 := httptest.NewRecorder()
	mw.ServeHTTP(rec2, req2)
	assert.Equal(t, http.StatusTooManyRequests, rec2.Code)
	assert.Equal(t, "1", rec2.Header().Get("Retry-After"), "RFC 6585: 429 must include Retry-After header")
}
// TestRateLimitMiddleware_NoClaims_PassThrough verifies that a request with no
// user claims in its context is forwarded to the wrapped handler. This holds
// even when the limiter is very restrictive (burst of one).
func TestRateLimitMiddleware_NoClaims_PassThrough(t *testing.T) {
	rec := httptest.NewRecorder()
	req := httptest.NewRequest(http.MethodGet, "/healthz", nil)

	// No middleware.WithClaims call: the request is anonymous.
	middleware.RateLimit(makeRateLimitLimiter(1))(rlOkHandler).ServeHTTP(rec, req)

	assert.Equal(t, http.StatusOK, rec.Code)
}