79 lines
2.5 KiB
Go
79 lines
2.5 KiB
Go
package middleware_test
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"go.uber.org/zap"
|
|
|
|
"github.com/veylant/ia-gateway/internal/middleware"
|
|
"github.com/veylant/ia-gateway/internal/ratelimit"
|
|
)
|
|
|
|
func makeRateLimitLimiter(burst int) *ratelimit.Limiter {
|
|
cfg := ratelimit.RateLimitConfig{
|
|
RequestsPerMin: 600,
|
|
BurstSize: burst,
|
|
UserRPM: 600,
|
|
UserBurst: burst,
|
|
IsEnabled: true,
|
|
}
|
|
return ratelimit.New(cfg, zap.NewNop())
|
|
}
|
|
|
|
// rlOkHandler is a terminal handler that answers every request with 200 OK
// and writes no body. The "rl" prefix avoids a redeclaration conflict with
// the ok handler declared in auth_test.go in the same test package.
var rlOkHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusOK)
})
|
|
|
|
func TestRateLimitMiddleware_AllowsWithinLimit(t *testing.T) {
|
|
limiter := makeRateLimitLimiter(5)
|
|
mw := middleware.RateLimit(limiter)(rlOkHandler)
|
|
|
|
for i := 0; i < 5; i++ {
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
|
// Inject claims into context.
|
|
ctx := middleware.WithClaims(req.Context(), &middleware.UserClaims{
|
|
UserID: "u1", TenantID: "t1",
|
|
})
|
|
req = req.WithContext(ctx)
|
|
rec := httptest.NewRecorder()
|
|
mw.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code, "request %d should be allowed", i+1)
|
|
}
|
|
}
|
|
|
|
func TestRateLimitMiddleware_Returns429WhenExceeded(t *testing.T) {
|
|
limiter := makeRateLimitLimiter(1) // burst = 1
|
|
mw := middleware.RateLimit(limiter)(rlOkHandler)
|
|
|
|
claims := &middleware.UserClaims{UserID: "u1", TenantID: "t1"}
|
|
|
|
// First request: allowed.
|
|
req1 := httptest.NewRequest(http.MethodPost, "/", nil)
|
|
req1 = req1.WithContext(middleware.WithClaims(req1.Context(), claims))
|
|
rec1 := httptest.NewRecorder()
|
|
mw.ServeHTTP(rec1, req1)
|
|
assert.Equal(t, http.StatusOK, rec1.Code)
|
|
|
|
// Second request: exceeded.
|
|
req2 := httptest.NewRequest(http.MethodPost, "/", nil)
|
|
req2 = req2.WithContext(middleware.WithClaims(req2.Context(), claims))
|
|
rec2 := httptest.NewRecorder()
|
|
mw.ServeHTTP(rec2, req2)
|
|
assert.Equal(t, http.StatusTooManyRequests, rec2.Code)
|
|
assert.Equal(t, "1", rec2.Header().Get("Retry-After"), "RFC 6585: 429 must include Retry-After header")
|
|
}
|
|
|
|
func TestRateLimitMiddleware_NoClaims_PassThrough(t *testing.T) {
|
|
limiter := makeRateLimitLimiter(1) // very restrictive, but no claims → pass
|
|
mw := middleware.RateLimit(limiter)(rlOkHandler)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/healthz", nil)
|
|
rec := httptest.NewRecorder()
|
|
mw.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
}
|