60 lines
1.9 KiB
Go
60 lines
1.9 KiB
Go
package middleware_test
|
|
|
|
import (
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/veylant/ia-gateway/internal/middleware"
|
|
)
|
|
|
|
func TestRequestID_GeneratesUUID(t *testing.T) {
|
|
var capturedID string
|
|
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
|
capturedID = middleware.RequestIDFromContext(r.Context())
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rec := httptest.NewRecorder()
|
|
middleware.RequestID(next).ServeHTTP(rec, req)
|
|
|
|
require.NotEmpty(t, capturedID, "request ID must be injected into context")
|
|
assert.Equal(t, capturedID, rec.Header().Get("X-Request-Id"), "response header must match context ID")
|
|
}
|
|
|
|
func TestRequestID_PropagatesExistingHeader(t *testing.T) {
|
|
const existingID = "01956a00-0000-7000-8000-000000000001"
|
|
|
|
var capturedID string
|
|
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
|
|
capturedID = middleware.RequestIDFromContext(r.Context())
|
|
})
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
req.Header.Set("X-Request-Id", existingID)
|
|
rec := httptest.NewRecorder()
|
|
middleware.RequestID(next).ServeHTTP(rec, req)
|
|
|
|
assert.Equal(t, existingID, capturedID, "existing request ID must be preserved in context")
|
|
assert.Equal(t, existingID, rec.Header().Get("X-Request-Id"), "existing request ID must be echoed in response header")
|
|
}
|
|
|
|
func TestRequestID_DifferentIDsPerRequest(t *testing.T) {
|
|
ids := make([]string, 3)
|
|
next := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {})
|
|
|
|
recs := make([]*httptest.ResponseRecorder, 3)
|
|
for i := range ids {
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
recs[i] = httptest.NewRecorder()
|
|
middleware.RequestID(next).ServeHTTP(recs[i], req)
|
|
ids[i] = recs[i].Header().Get("X-Request-Id")
|
|
}
|
|
|
|
assert.NotEqual(t, ids[0], ids[1], "each request must get a unique ID")
|
|
assert.NotEqual(t, ids[1], ids[2], "each request must get a unique ID")
|
|
}
|