151 lines
4.5 KiB
Go
151 lines
4.5 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
|
|
"github.com/veylant/ia-gateway/internal/apierror"
|
|
)
|
|
|
|
// TokenVerifier abstracts JWT verification so that the auth middleware can be
|
|
// tested without a live Keycloak instance.
|
|
type TokenVerifier interface {
|
|
// Verify validates rawToken and returns the extracted claims on success.
|
|
Verify(ctx context.Context, rawToken string) (*UserClaims, error)
|
|
}
|
|
|
|
// OIDCVerifier implements TokenVerifier using go-oidc against a Keycloak realm.
|
|
type OIDCVerifier struct {
|
|
verifier *oidc.IDTokenVerifier
|
|
}
|
|
|
|
type (
	// keycloakClaims lists the JWT claims this package reads from a
	// Keycloak-issued token. tenant_id and department are not standard OIDC
	// claims; presumably they are added by realm/client mappers — confirm
	// against the Keycloak realm configuration.
	keycloakClaims struct {
		Sub         string      `json:"sub"`
		Email       string      `json:"email"`
		TenantID    string      `json:"tenant_id"`
		Department  string      `json:"department"`
		RealmAccess realmAccess `json:"realm_access"`
	}

	// realmAccess is the nested "realm_access" claim object; only its role
	// list is consumed.
	realmAccess struct {
		Roles []string `json:"roles"`
	}
)
|
|
|
|
// NewOIDCVerifier creates an OIDCVerifier using the Keycloak issuer URL and
|
|
// the expected client ID (audience).
|
|
// issuerURL example: "http://localhost:8080/realms/veylant"
|
|
func NewOIDCVerifier(ctx context.Context, issuerURL, clientID string) (*OIDCVerifier, error) {
|
|
provider, err := oidc.NewProvider(ctx, issuerURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating OIDC provider for %s: %w", issuerURL, err)
|
|
}
|
|
|
|
v := provider.Verifier(&oidc.Config{
|
|
ClientID: clientID,
|
|
})
|
|
|
|
return &OIDCVerifier{verifier: v}, nil
|
|
}
|
|
|
|
// Verify implements TokenVerifier.
|
|
func (o *OIDCVerifier) Verify(ctx context.Context, rawToken string) (*UserClaims, error) {
|
|
idToken, err := o.verifier.Verify(ctx, rawToken)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("token verification failed: %w", err)
|
|
}
|
|
|
|
var kc keycloakClaims
|
|
if err := idToken.Claims(&kc); err != nil {
|
|
return nil, fmt.Errorf("extracting token claims: %w", err)
|
|
}
|
|
|
|
return &UserClaims{
|
|
UserID: kc.Sub,
|
|
TenantID: kc.TenantID,
|
|
Email: kc.Email,
|
|
Roles: kc.RealmAccess.Roles,
|
|
Department: kc.Department,
|
|
}, nil
|
|
}
|
|
|
|
// Auth returns a middleware that enforces JWT bearer authentication.
|
|
// On success the UserClaims are injected into the request context.
|
|
// On failure a 401 in OpenAI error format is returned immediately.
|
|
func Auth(verifier TokenVerifier) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
rawToken := extractBearerToken(r)
|
|
if rawToken == "" {
|
|
apierror.WriteError(w, apierror.NewAuthError("missing Authorization header"))
|
|
return
|
|
}
|
|
|
|
claims, err := verifier.Verify(r.Context(), rawToken)
|
|
if err != nil {
|
|
apierror.WriteError(w, apierror.NewAuthError("invalid or expired token"))
|
|
return
|
|
}
|
|
|
|
ctx := WithClaims(r.Context(), claims)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// extractBearerToken returns the credential from an
// "Authorization: Bearer <token>" header.
//
// The scheme is matched case-insensitively and any run of whitespace around
// the token is stripped (RFC 6750 allows one or more spaces after "Bearer").
// It returns "" when the header is absent, uses another scheme, or carries
// no token, so callers can treat "" uniformly as "no credentials".
func extractBearerToken(r *http.Request) string {
	scheme, token, found := strings.Cut(strings.TrimSpace(r.Header.Get("Authorization")), " ")
	if !found || !strings.EqualFold(scheme, "bearer") {
		return ""
	}
	return strings.TrimSpace(token)
}
|
|
|
|
// MockVerifier is a test-only TokenVerifier that returns pre-configured
|
|
// claims or errors without any network call.
|
|
type MockVerifier struct {
|
|
Claims *UserClaims
|
|
Err error
|
|
}
|
|
|
|
// Verify implements TokenVerifier for tests.
|
|
func (m *MockVerifier) Verify(_ context.Context, _ string) (*UserClaims, error) {
|
|
return m.Claims, m.Err
|
|
}
|
|
|
|
// mockTokenClaims is a helper to decode a mock JWT payload for testing.
|
|
// It allows passing structured claims as a JSON-encoded token for the mock path.
|
|
func mockTokenClaims(rawToken string) (*UserClaims, error) {
|
|
var kc keycloakClaims
|
|
if err := json.Unmarshal([]byte(rawToken), &kc); err != nil {
|
|
return nil, fmt.Errorf("invalid mock token: %w", err)
|
|
}
|
|
return &UserClaims{
|
|
UserID: kc.Sub,
|
|
TenantID: kc.TenantID,
|
|
Email: kc.Email,
|
|
Roles: kc.RealmAccess.Roles,
|
|
Department: kc.Department,
|
|
}, nil
|
|
}
|
|
|
|
// JSONMockVerifier decodes a JSON-encoded UserClaims payload directly from the
|
|
// token string. Useful for integration-level unit tests.
|
|
type JSONMockVerifier struct{}
|
|
|
|
// Verify implements TokenVerifier by treating rawToken as a JSON-encoded UserClaims.
|
|
func (j *JSONMockVerifier) Verify(_ context.Context, rawToken string) (*UserClaims, error) {
|
|
return mockTokenClaims(rawToken)
|
|
}
|