package middleware_test

import (
	"net/http"
	"net/http/httptest"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/veylant/ia-gateway/internal/middleware"
)

const (
	// corsTestPath is the request target used by every CORS test below.
	corsTestPath = "/v1/chat/completions"
	// corsLocalOrigin is the origin reused by the tests that need a fixed
	// allow-list entry and/or request Origin.
	corsLocalOrigin = "http://localhost:3000"
)

// corsOkHandler is the terminal handler wrapped by the middleware under test;
// it only writes a 200 status.
var corsOkHandler = http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
	w.WriteHeader(http.StatusOK)
})

// corsExchange sends one request with the given method and headers to
// corsTestPath through h and returns the recorder that captured the response.
func corsExchange(h http.Handler, method string, headers map[string]string) *httptest.ResponseRecorder {
	request := httptest.NewRequest(method, corsTestPath, nil)
	for name, value := range headers {
		request.Header.Set(name, value)
	}

	recorder := httptest.NewRecorder()
	h.ServeHTTP(recorder, request)
	return recorder
}

// TestCORS_NoOrigins_NoHeaders checks that with a nil allow-list the request
// still yields 200 and no Access-Control-Allow-Origin header is set.
func TestCORS_NoOrigins_NoHeaders(t *testing.T) {
	handler := middleware.CORS(nil)(corsOkHandler)

	resp := corsExchange(handler, http.MethodGet, map[string]string{
		"Origin": corsLocalOrigin,
	})

	assert.Equal(t, http.StatusOK, resp.Code)
	assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
}

// TestCORS_AllowedOrigin_SetsHeader checks that a listed origin is echoed in
// Access-Control-Allow-Origin and that Access-Control-Allow-Credentials is "true".
func TestCORS_AllowedOrigin_SetsHeader(t *testing.T) {
	handler := middleware.CORS([]string{corsLocalOrigin})(corsOkHandler)

	resp := corsExchange(handler, http.MethodGet, map[string]string{
		"Origin": corsLocalOrigin,
	})
	got := resp.Header()

	assert.Equal(t, http.StatusOK, resp.Code)
	assert.Equal(t, corsLocalOrigin, got.Get("Access-Control-Allow-Origin"))
	assert.Equal(t, "true", got.Get("Access-Control-Allow-Credentials"))
}

// TestCORS_DisallowedOrigin_NoHeader checks that an origin missing from the
// allow-list gets a 200 without an Access-Control-Allow-Origin header.
func TestCORS_DisallowedOrigin_NoHeader(t *testing.T) {
	handler := middleware.CORS([]string{"https://dashboard.veylant.ai"})(corsOkHandler)

	resp := corsExchange(handler, http.MethodGet, map[string]string{
		"Origin": "http://evil.example.com",
	})

	assert.Equal(t, http.StatusOK, resp.Code)
	assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
}

// TestCORS_Preflight_Returns204 checks that an OPTIONS preflight from a listed
// origin is answered with 204 plus the expected allow-origin, allow-methods
// and max-age headers.
func TestCORS_Preflight_Returns204(t *testing.T) {
	handler := middleware.CORS([]string{corsLocalOrigin})(corsOkHandler)

	resp := corsExchange(handler, http.MethodOptions, map[string]string{
		"Origin":                         corsLocalOrigin,
		"Access-Control-Request-Method":  "POST",
		"Access-Control-Request-Headers": "authorization, content-type",
	})
	got := resp.Header()

	assert.Equal(t, http.StatusNoContent, resp.Code)
	assert.Equal(t, corsLocalOrigin, got.Get("Access-Control-Allow-Origin"))
	assert.Equal(t, "GET, POST, PUT, PATCH, DELETE, OPTIONS", got.Get("Access-Control-Allow-Methods"))
	assert.Equal(t, "86400", got.Get("Access-Control-Max-Age"))
}

// TestCORS_Wildcard_AllowsAnyOrigin checks that a "*" allow-list entry makes
// the middleware echo the request's own origin in Access-Control-Allow-Origin.
func TestCORS_Wildcard_AllowsAnyOrigin(t *testing.T) {
	const anyOrigin = "http://any.example.com"
	handler := middleware.CORS([]string{"*"})(corsOkHandler)

	resp := corsExchange(handler, http.MethodGet, map[string]string{
		"Origin": anyOrigin,
	})

	assert.Equal(t, http.StatusOK, resp.Code)
	assert.Equal(t, anyOrigin, resp.Header().Get("Access-Control-Allow-Origin"))
}

// TestCORS_NoOriginHeader_NoCorHeaders checks that a request sent without an
// Origin header gets a 200 and no Access-Control-Allow-Origin header.
func TestCORS_NoOriginHeader_NoCorHeaders(t *testing.T) {
	handler := middleware.CORS([]string{corsLocalOrigin})(corsOkHandler)

	// No Origin header → same-origin request → no CORS headers needed.
	resp := corsExchange(handler, http.MethodGet, nil)

	assert.Equal(t, http.StatusOK, resp.Code)
	assert.Empty(t, resp.Header().Get("Access-Control-Allow-Origin"))
}