81 lines
3.2 KiB
Go
81 lines
3.2 KiB
Go
package middleware_test
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
|
|
"github.com/veylant/ia-gateway/internal/middleware"
|
|
)
|
|
|
|
// corsOkHandler is the terminal handler the CORS tests wrap with
// middleware.CORS. It ignores the request and always replies 200 OK with an
// empty body, so any other status observed in a test (such as 204 on a
// preflight) can only have been produced by the middleware.
var corsOkHandler = http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
	rw.WriteHeader(http.StatusOK)
})
|
|
|
|
func TestCORS_NoOrigins_NoHeaders(t *testing.T) {
|
|
mw := middleware.CORS(nil)(corsOkHandler)
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
|
|
req.Header.Set("Origin", "http://localhost:3000")
|
|
rec := httptest.NewRecorder()
|
|
mw.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
|
|
func TestCORS_AllowedOrigin_SetsHeader(t *testing.T) {
|
|
mw := middleware.CORS([]string{"http://localhost:3000"})(corsOkHandler)
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
|
|
req.Header.Set("Origin", "http://localhost:3000")
|
|
rec := httptest.NewRecorder()
|
|
mw.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "http://localhost:3000", rec.Header().Get("Access-Control-Allow-Origin"))
|
|
assert.Equal(t, "true", rec.Header().Get("Access-Control-Allow-Credentials"))
|
|
}
|
|
|
|
func TestCORS_DisallowedOrigin_NoHeader(t *testing.T) {
|
|
mw := middleware.CORS([]string{"https://dashboard.veylant.ai"})(corsOkHandler)
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
|
|
req.Header.Set("Origin", "http://evil.example.com")
|
|
rec := httptest.NewRecorder()
|
|
mw.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
|
|
func TestCORS_Preflight_Returns204(t *testing.T) {
|
|
mw := middleware.CORS([]string{"http://localhost:3000"})(corsOkHandler)
|
|
req := httptest.NewRequest(http.MethodOptions, "/v1/chat/completions", nil)
|
|
req.Header.Set("Origin", "http://localhost:3000")
|
|
req.Header.Set("Access-Control-Request-Method", "POST")
|
|
req.Header.Set("Access-Control-Request-Headers", "authorization, content-type")
|
|
rec := httptest.NewRecorder()
|
|
mw.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusNoContent, rec.Code)
|
|
assert.Equal(t, "http://localhost:3000", rec.Header().Get("Access-Control-Allow-Origin"))
|
|
assert.Equal(t, "GET, POST, PUT, PATCH, DELETE, OPTIONS", rec.Header().Get("Access-Control-Allow-Methods"))
|
|
assert.Equal(t, "86400", rec.Header().Get("Access-Control-Max-Age"))
|
|
}
|
|
|
|
func TestCORS_Wildcard_AllowsAnyOrigin(t *testing.T) {
|
|
mw := middleware.CORS([]string{"*"})(corsOkHandler)
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
|
|
req.Header.Set("Origin", "http://any.example.com")
|
|
rec := httptest.NewRecorder()
|
|
mw.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Equal(t, "http://any.example.com", rec.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|
|
|
|
func TestCORS_NoOriginHeader_NoCorHeaders(t *testing.T) {
|
|
mw := middleware.CORS([]string{"http://localhost:3000"})(corsOkHandler)
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/chat/completions", nil)
|
|
// No Origin header → same-origin request → no CORS headers needed.
|
|
rec := httptest.NewRecorder()
|
|
mw.ServeHTTP(rec, req)
|
|
assert.Equal(t, http.StatusOK, rec.Code)
|
|
assert.Empty(t, rec.Header().Get("Access-Control-Allow-Origin"))
|
|
}
|