veylant/internal/middleware/cors_test.go
2026-02-23 13:35:04 +01:00

81 lines
3.2 KiB
Go

package middleware_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/veylant/ia-gateway/internal/middleware"
)
// corsOkHandler is the downstream handler that the CORS middleware wraps in
// every test below. It ignores the request entirely and answers 200 OK with
// an empty body, so any other status or header seen by a test must have come
// from the middleware.
var corsOkHandler http.HandlerFunc = func(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusOK)
}
// TestCORS_NoOrigins_NoHeaders checks that a middleware built with a nil
// allow-list passes a cross-origin request through to the wrapped handler
// without granting it an Access-Control-Allow-Origin header.
func TestCORS_NoOrigins_NoHeaders(t *testing.T) {
	handler := middleware.CORS(nil)(corsOkHandler)

	r := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
	r.Header.Set("Origin", "http://localhost:3000")

	w := httptest.NewRecorder()
	handler.ServeHTTP(w, r)

	if assert.Equal(t, http.StatusOK, w.Code) || true {
		assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
	}
}
func TestCORS_AllowedOrigin_SetsHeader(t *testing.T) {
mw := middleware.CORS([]string{"http://localhost:3000"})(corsOkHandler)
req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
req.Header.Set("Origin", "http://localhost:3000")
rec := httptest.NewRecorder()
mw.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "http://localhost:3000", rec.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true", rec.Header().Get("Access-Control-Allow-Credentials"))
}
// TestCORS_DisallowedOrigin_NoHeader checks that a request from an origin that
// is not on the allow-list still reaches the wrapped handler (200) but gets
// no CORS grant. Neither Access-Control-Allow-Origin nor
// Access-Control-Allow-Credentials may be set. The allowed-origin case grants
// both, so both must be absent here.
func TestCORS_DisallowedOrigin_NoHeader(t *testing.T) {
	mw := middleware.CORS([]string{"https://dashboard.veylant.ai"})(corsOkHandler)
	req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
	req.Header.Set("Origin", "http://evil.example.com")
	rec := httptest.NewRecorder()
	mw.ServeHTTP(rec, req)
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"))
	// A rejected origin must not be told that credentialed requests are allowed.
	assert.Empty(t, rec.Header().Get("Access-Control-Allow-Credentials"))
}
func TestCORS_Preflight_Returns204(t *testing.T) {
mw := middleware.CORS([]string{"http://localhost:3000"})(corsOkHandler)
req := httptest.NewRequest(http.MethodOptions, "/v1/chat/completions", nil)
req.Header.Set("Origin", "http://localhost:3000")
req.Header.Set("Access-Control-Request-Method", "POST")
req.Header.Set("Access-Control-Request-Headers", "authorization, content-type")
rec := httptest.NewRecorder()
mw.ServeHTTP(rec, req)
assert.Equal(t, http.StatusNoContent, rec.Code)
assert.Equal(t, "http://localhost:3000", rec.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "GET, POST, PUT, PATCH, DELETE, OPTIONS", rec.Header().Get("Access-Control-Allow-Methods"))
assert.Equal(t, "86400", rec.Header().Get("Access-Control-Max-Age"))
}
func TestCORS_Wildcard_AllowsAnyOrigin(t *testing.T) {
mw := middleware.CORS([]string{"*"})(corsOkHandler)
req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
req.Header.Set("Origin", "http://any.example.com")
rec := httptest.NewRecorder()
mw.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
assert.Equal(t, "http://any.example.com", rec.Header().Get("Access-Control-Allow-Origin"))
}
// TestCORS_NoOriginHeader_NoCorHeaders checks that a request carrying no
// Origin header passes through to the wrapped handler (200) without any CORS
// grant headers, even when an allow-list is configured. This covers both
// Access-Control-Allow-Origin and Access-Control-Allow-Credentials.
//
// NOTE: "NoCor" in the name is a typo for "NoCORS"; the name is kept so that
// existing `go test -run` filters continue to match.
func TestCORS_NoOriginHeader_NoCorHeaders(t *testing.T) {
	mw := middleware.CORS([]string{"http://localhost:3000"})(corsOkHandler)
	req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
	// No Origin header → same-origin request → no CORS headers needed.
	rec := httptest.NewRecorder()
	mw.ServeHTTP(rec, req)
	assert.Equal(t, http.StatusOK, rec.Code)
	assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"))
	// The name promises no CORS headers at all, so also make sure the
	// credentials grant is not emitted for non-CORS traffic.
	assert.Empty(t, rec.Header().Get("Access-Control-Allow-Credentials"))
}