veylant/internal/middleware/auth_test.go
2026-02-23 13:35:04 +01:00

162 lines
4.9 KiB
Go

package middleware_test
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/veylant/ia-gateway/internal/middleware"
)
// buildToken serialises claims as a JSON string suitable for use as a bearer
// credential in tests. In this file it only supplies a realistic-looking
// Authorization value; it is meant for verifiers that decode the token body as
// JSON (NOTE(review): the original comment mentioned JSONMockVerifier — confirm
// that type exists before relying on it).
//
// A marshalling failure can only come from a programming error in the test
// fixture, so it panics instead of silently producing an empty token.
func buildToken(claims middleware.UserClaims) string {
	b, err := json.Marshal(claims)
	if err != nil {
		panic("buildToken: marshalling claims: " + err.Error())
	}
	return string(b)
}
// okHandler is a terminal handler that answers every request with 200 OK and
// an empty body. Used as the wrapped handler, it guarantees that any non-200
// status observed by a test was produced by the middleware itself.
var okHandler = http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
	rw.WriteHeader(http.StatusOK)
})
// TestAuth_ValidToken_InjectsClaimsAndCallsNext verifies that a request carrying
// a bearer token accepted by the verifier reaches the next handler, and that the
// verifier's claims are retrievable from the request context there.
func TestAuth_ValidToken_InjectsClaimsAndCallsNext(t *testing.T) {
	expected := &middleware.UserClaims{
		UserID:   "user-123",
		TenantID: "tenant-456",
		Email:    "user@veylant.dev",
		Roles:    []string{"user"},
	}
	verifier := &middleware.MockVerifier{Claims: expected}

	var (
		called bool
		got    *middleware.UserClaims
		ok     bool
	)
	next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		called = true
		got, ok = middleware.ClaimsFromContext(r.Context())
	})

	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
	req.Header.Set("Authorization", "Bearer somevalidtoken")
	rec := httptest.NewRecorder()

	middleware.Auth(verifier)(next).ServeHTTP(rec, req)

	// next writes nothing, so 200 is the recorder's default; a rejection by the
	// middleware would have written a 401 instead.
	assert.Equal(t, http.StatusOK, rec.Code)
	require.True(t, called, "next handler was not invoked")
	require.True(t, ok, "claims missing from request context")
	require.NotNil(t, got)
	assert.Equal(t, expected.UserID, got.UserID)
	assert.Equal(t, expected.TenantID, got.TenantID)
	assert.Equal(t, expected.Email, got.Email)
	assert.Equal(t, expected.Roles, got.Roles)
}
// TestAuth_MissingAuthorizationHeader_Returns401 checks that a request with no
// Authorization header at all is rejected with 401 and an OpenAI-style
// "authentication_error" body.
func TestAuth_MissingAuthorizationHeader_Returns401(t *testing.T) {
	handler := middleware.Auth(&middleware.MockVerifier{})(okHandler)

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil))

	assert.Equal(t, http.StatusUnauthorized, rec.Code)
	assertOpenAIError(t, rec, "authentication_error")
}
// TestAuth_ExpiredToken_Returns401 checks that when the verifier rejects the
// token as expired, the middleware answers 401 with an OpenAI-style
// "authentication_error" body instead of calling the next handler.
func TestAuth_ExpiredToken_Returns401(t *testing.T) {
	expiredErr := errors.New("token is expired")
	handler := middleware.Auth(&middleware.MockVerifier{Err: expiredErr})(okHandler)

	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
	req.Header.Set("Authorization", "Bearer expiredtoken")

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusUnauthorized, rec.Code)
	assertOpenAIError(t, rec, "authentication_error")
}
// TestAuth_BadIssuerToken_Returns401 checks that an issuer-mismatch error from
// the verifier is surfaced as 401 with an OpenAI-style "authentication_error"
// body.
func TestAuth_BadIssuerToken_Returns401(t *testing.T) {
	issuerErr := errors.New("oidc: issuer mismatch")
	handler := middleware.Auth(&middleware.MockVerifier{Err: issuerErr})(okHandler)

	req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
	req.Header.Set("Authorization", "Bearer badissuertoken")

	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	assert.Equal(t, http.StatusUnauthorized, rec.Code)
	assertOpenAIError(t, rec, "authentication_error")
}
// TestAuth_ValidToken_ClaimsFullyExtracted checks that every claim field the
// verifier returns (user ID, tenant ID, email and a multi-valued role list)
// reaches the next handler unchanged.
//
// NOTE(review): the verifier is a MockVerifier pre-loaded with the expected
// claims. MockVerifier presumably returns them regardless of the token (the
// previous test relies on that with a non-JSON token). So the JSON token from
// buildToken only makes the header realistic and this test does not, on its
// own, prove that a token body is decoded. Switch to a JSON-decoding verifier
// (e.g. JSONMockVerifier, if it exists) to cover that path.
func TestAuth_ValidToken_ClaimsFullyExtracted(t *testing.T) {
	expected := middleware.UserClaims{
		UserID:   "00000000-0000-0000-0000-000000000099",
		TenantID: "00000000-0000-0000-0000-000000000001",
		Email:    "admin@veylant.dev",
		Roles:    []string{"admin", "user"},
	}

	var (
		got *middleware.UserClaims
		ok  bool
	)
	next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		got, ok = middleware.ClaimsFromContext(r.Context())
	})

	req := httptest.NewRequest(http.MethodPost, "/", nil)
	req.Header.Set("Authorization", "Bearer "+buildToken(expected))
	rec := httptest.NewRecorder()
	verifier := &middleware.MockVerifier{
		Claims: &expected,
	}

	middleware.Auth(verifier)(next).ServeHTTP(rec, req)

	// next writes nothing, so 200 here means the middleware did not reject.
	assert.Equal(t, http.StatusOK, rec.Code)
	require.True(t, ok, "claims missing from request context")
	require.NotNil(t, got)
	assert.Equal(t, expected.UserID, got.UserID)
	assert.Equal(t, expected.TenantID, got.TenantID)
	assert.Equal(t, expected.Email, got.Email)
	assert.ElementsMatch(t, expected.Roles, got.Roles)
}
// TestAuth_MalformedBearerScheme_Returns401 checks that an Authorization header
// using a non-Bearer scheme (here HTTP Basic) is rejected with 401. It also
// checks the OpenAI-style "authentication_error" body, consistent with the other
// rejection tests in this file.
func TestAuth_MalformedBearerScheme_Returns401(t *testing.T) {
	verifier := &middleware.MockVerifier{}
	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.Header.Set("Authorization", "Basic dXNlcjpwYXNz")
	rec := httptest.NewRecorder()

	middleware.Auth(verifier)(okHandler).ServeHTTP(rec, req)

	assert.Equal(t, http.StatusUnauthorized, rec.Code)
	assertOpenAIError(t, rec, "authentication_error")
}
// assertOpenAIError decodes the recorded response body as an OpenAI-style error
// envelope ({"error": {"type": ...}}) and asserts that its type equals
// expectedType. It fails the test immediately if the body is not valid JSON.
// Decoding consumes rec.Body, so call it at most once per recorder.
func assertOpenAIError(t *testing.T, rec *httptest.ResponseRecorder, expectedType string) {
	t.Helper()

	type openAIErrorEnvelope struct {
		Error struct {
			Type string `json:"type"`
		} `json:"error"`
	}

	var envelope openAIErrorEnvelope
	dec := json.NewDecoder(rec.Body)
	require.NoError(t, dec.Decode(&envelope))

	assert.Equal(t, expectedType, envelope.Error.Type)
}
// Compile-time assertion that *middleware.MockVerifier implements
// middleware.TokenVerifier. The typed nil pointer is never dereferenced; if the
// interface or the mock's method set drifts, this line stops compiling rather
// than a test failing later.
var _ middleware.TokenVerifier = (*middleware.MockVerifier)(nil)
// TestContextAccessors checks that the exported context accessors, called from
// an external package, report absence on a context that carries no values.
// ClaimsFromContext must return a nil pointer with ok == false, and
// RequestIDFromContext must return the empty string.
func TestContextAccessors(t *testing.T) {
	ctx := context.Background()

	claims, ok := middleware.ClaimsFromContext(ctx)
	assert.False(t, ok)
	assert.Nil(t, claims)

	assert.Empty(t, middleware.RequestIDFromContext(ctx))
}