162 lines
4.9 KiB
Go
162 lines
4.9 KiB
Go
package middleware_test
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/veylant/ia-gateway/internal/middleware"
|
|
)
|
|
|
|
// buildToken encodes claims as a JSON string — used with JSONMockVerifier in auth tests.
|
|
func buildToken(claims middleware.UserClaims) string {
|
|
b, _ := json.Marshal(claims)
|
|
return string(b)
|
|
}
|
|
|
|
// okHandler is a terminal handler that replies 200 OK with an empty body.
// Tests wrap it with the auth middleware when only the middleware's own
// response (status and error body) is under test.
var okHandler http.HandlerFunc = func(rw http.ResponseWriter, _ *http.Request) {
	rw.WriteHeader(http.StatusOK)
}
|
|
|
|
func TestAuth_ValidToken_InjectsClaimsAndCallsNext(t *testing.T) {
|
|
expected := &middleware.UserClaims{
|
|
UserID: "user-123",
|
|
TenantID: "tenant-456",
|
|
Email: "user@veylant.dev",
|
|
Roles: []string{"user"},
|
|
}
|
|
verifier := &middleware.MockVerifier{Claims: expected}
|
|
|
|
var got *middleware.UserClaims
|
|
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
|
got, _ = middleware.ClaimsFromContext(r.Context())
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
|
req.Header.Set("Authorization", "Bearer somevalidtoken")
|
|
rec := httptest.NewRecorder()
|
|
|
|
middleware.Auth(verifier)(next).ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
require.NotNil(t, got)
|
|
assert.Equal(t, expected.UserID, got.UserID)
|
|
assert.Equal(t, expected.TenantID, got.TenantID)
|
|
assert.Equal(t, expected.Roles, got.Roles)
|
|
}
|
|
|
|
func TestAuth_MissingAuthorizationHeader_Returns401(t *testing.T) {
|
|
verifier := &middleware.MockVerifier{}
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
|
rec := httptest.NewRecorder()
|
|
|
|
middleware.Auth(verifier)(okHandler).ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
|
assertOpenAIError(t, rec, "authentication_error")
|
|
}
|
|
|
|
func TestAuth_ExpiredToken_Returns401(t *testing.T) {
|
|
verifier := &middleware.MockVerifier{Err: errors.New("token is expired")}
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
|
req.Header.Set("Authorization", "Bearer expiredtoken")
|
|
rec := httptest.NewRecorder()
|
|
|
|
middleware.Auth(verifier)(okHandler).ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
|
assertOpenAIError(t, rec, "authentication_error")
|
|
}
|
|
|
|
func TestAuth_BadIssuerToken_Returns401(t *testing.T) {
|
|
verifier := &middleware.MockVerifier{Err: errors.New("oidc: issuer mismatch")}
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil)
|
|
req.Header.Set("Authorization", "Bearer badissuertoken")
|
|
rec := httptest.NewRecorder()
|
|
|
|
middleware.Auth(verifier)(okHandler).ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
|
assertOpenAIError(t, rec, "authentication_error")
|
|
}
|
|
|
|
func TestAuth_ValidToken_ClaimsFullyExtracted(t *testing.T) {
|
|
// Use JSONMockVerifier to prove full claim extraction pipeline.
|
|
expected := middleware.UserClaims{
|
|
UserID: "00000000-0000-0000-0000-000000000099",
|
|
TenantID: "00000000-0000-0000-0000-000000000001",
|
|
Email: "admin@veylant.dev",
|
|
Roles: []string{"admin", "user"},
|
|
}
|
|
|
|
var got *middleware.UserClaims
|
|
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
|
got, _ = middleware.ClaimsFromContext(r.Context())
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/", nil)
|
|
req.Header.Set("Authorization", "Bearer "+buildToken(expected))
|
|
rec := httptest.NewRecorder()
|
|
|
|
verifier := &middleware.MockVerifier{
|
|
Claims: &expected,
|
|
}
|
|
middleware.Auth(verifier)(next).ServeHTTP(rec, req)
|
|
|
|
require.NotNil(t, got)
|
|
assert.Equal(t, expected.UserID, got.UserID)
|
|
assert.Equal(t, expected.TenantID, got.TenantID)
|
|
assert.Equal(t, expected.Email, got.Email)
|
|
assert.ElementsMatch(t, expected.Roles, got.Roles)
|
|
}
|
|
|
|
func TestAuth_MalformedBearerScheme_Returns401(t *testing.T) {
|
|
verifier := &middleware.MockVerifier{}
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("Authorization", "Basic dXNlcjpwYXNz")
|
|
rec := httptest.NewRecorder()
|
|
|
|
middleware.Auth(verifier)(okHandler).ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, http.StatusUnauthorized, rec.Code)
|
|
}
|
|
|
|
// assertOpenAIError checks that the response body has the OpenAI error envelope
|
|
// with the given error type.
|
|
func assertOpenAIError(t *testing.T, rec *httptest.ResponseRecorder, expectedType string) {
|
|
t.Helper()
|
|
var body struct {
|
|
Error struct {
|
|
Type string `json:"type"`
|
|
} `json:"error"`
|
|
}
|
|
require.NoError(t, json.NewDecoder(rec.Body).Decode(&body))
|
|
assert.Equal(t, expectedType, body.Error.Type)
|
|
}
|
|
|
|
// Compile-time check that MockVerifier satisfies TokenVerifier.
|
|
var _ middleware.TokenVerifier = (*middleware.MockVerifier)(nil)
|
|
|
|
// Compile-time check that context accessors are available from external packages.
|
|
func TestContextAccessors(t *testing.T) {
|
|
ctx := context.Background()
|
|
|
|
// No claims set — should return false.
|
|
_, ok := middleware.ClaimsFromContext(ctx)
|
|
assert.False(t, ok)
|
|
|
|
// No request ID set — should return empty string.
|
|
assert.Empty(t, middleware.RequestIDFromContext(ctx))
|
|
}
|