package pii_test

import (
	"context"
	"net"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"go.uber.org/zap"
	"google.golang.org/grpc"

	piiv1 "github.com/veylant/ia-gateway/gen/pii/v1"
	"github.com/veylant/ia-gateway/internal/pii"
)

// ─── Mock gRPC server ────────────────────────────────────────────────────────

// mockPiiServer is an in-process PiiService implementation for tests.
// When detectFn is nil, Detect echoes the input text back unchanged with no
// entities; otherwise Detect delegates to detectFn.
type mockPiiServer struct {
	piiv1.UnimplementedPiiServiceServer
	detectFn func(*piiv1.PiiRequest) (*piiv1.PiiResponse, error)
}

// Detect delegates to detectFn when set, or returns an echo response.
func (m *mockPiiServer) Detect(_ context.Context, req *piiv1.PiiRequest) (*piiv1.PiiResponse, error) {
	if m.detectFn != nil {
		return m.detectFn(req)
	}
	return &piiv1.PiiResponse{
		AnonymizedText:   req.Text,
		ProcessingTimeMs: 1,
	}, nil
}

// Health always reports a healthy server with the NER model loaded.
func (m *mockPiiServer) Health(_ context.Context, _ *piiv1.HealthRequest) (*piiv1.HealthResponse, error) {
	return &piiv1.HealthResponse{Status: "ok", NerModelLoaded: true}, nil
}

// startMockServer starts a gRPC server on an ephemeral loopback port with the
// given servicer registered, and returns its address. The server is stopped
// via t.Cleanup when the test ends.
func startMockServer(t *testing.T, srv piiv1.PiiServiceServer) string {
	t.Helper()

	lis, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	s := grpc.NewServer()
	piiv1.RegisterPiiServiceServer(s, srv)
	t.Cleanup(func() { s.Stop() })

	// Serve blocks until Stop is called in cleanup; its return value carries
	// nothing the test needs to act on.
	go func() { _ = s.Serve(lis) }()

	return lis.Addr().String()
}

// unusedAddr returns a loopback address with no listener. It briefly binds an
// ephemeral port and then releases it. Another process could in principle grab
// the port in between, but that is far less collision-prone than a hard-coded
// port number that may already be in use on the test host.
func unusedAddr(t *testing.T) string {
	t.Helper()

	lis, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	addr := lis.Addr().String()
	require.NoError(t, lis.Close())

	return addr
}

// newTestClient builds a pii.Client pointed at addr with a 500ms timeout and
// the given fail-open policy. The client is closed via t.Cleanup.
func newTestClient(t *testing.T, addr string, failOpen bool) *pii.Client {
	t.Helper()

	c, err := pii.New(pii.Config{
		Address:  addr,
		Timeout:  500 * time.Millisecond,
		FailOpen: failOpen,
	}, zap.NewNop())
	require.NoError(t, err)
	t.Cleanup(func() { _ = c.Close() })

	return c
}

// ─── Tests ───────────────────────────────────────────────────────────────────

func TestClient_Detect_Nominal(t *testing.T) {
	srv := &mockPiiServer{
		detectFn: func(_ *piiv1.PiiRequest) (*piiv1.PiiResponse, error) {
			return &piiv1.PiiResponse{
				AnonymizedText: "[PII:EMAIL:abc12345]",
				Entities: []*piiv1.PiiEntity{
					{
						EntityType:     "EMAIL",
						OriginalValue:  "alice@example.com",
						Pseudonym:      "[PII:EMAIL:abc12345]",
						Start:          0,
						End:            17,
						Confidence:     1.0,
						DetectionLayer: "regex",
					},
				},
				ProcessingTimeMs: 2,
			}, nil
		},
	}
	addr := startMockServer(t, srv)
	client := newTestClient(t, addr, false)

	result, err := client.Detect(context.Background(), "alice@example.com", "tenant1", "req1", false, false)
	require.NoError(t, err)

	assert.Equal(t, "[PII:EMAIL:abc12345]", result.AnonymizedText)
	require.Len(t, result.Entities, 1)
	assert.Equal(t, "EMAIL", result.Entities[0].EntityType)
	assert.Equal(t, "alice@example.com", result.Entities[0].OriginalValue)
	assert.Equal(t, "[PII:EMAIL:abc12345]", result.Entities[0].Pseudonym)
	assert.Equal(t, int64(2), result.ProcessingTimeMs)
}

func TestClient_Detect_NoEntities(t *testing.T) {
	addr := startMockServer(t, &mockPiiServer{})
	client := newTestClient(t, addr, false)

	result, err := client.Detect(context.Background(), "Bonjour le monde", "t", "r", false, false)
	require.NoError(t, err)

	assert.Equal(t, "Bonjour le monde", result.AnonymizedText)
	assert.Empty(t, result.Entities)
}

func TestClient_Detect_FailOpen_ReturnsOriginalOnError(t *testing.T) {
	// Point at an address where nothing is listening so the RPC fails.
	client := newTestClient(t, unusedAddr(t), true)

	result, err := client.Detect(context.Background(), "original text", "t", "r", false, false)
	require.NoError(t, err) // fail-open: the error is swallowed

	assert.Equal(t, "original text", result.AnonymizedText)
	assert.Empty(t, result.Entities)
}

func TestClient_Detect_FailClosed_ReturnsError(t *testing.T) {
	client := newTestClient(t, unusedAddr(t), false)

	_, err := client.Detect(context.Background(), "original text", "t", "r", false, false)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "pii:")
}

func TestClient_Detect_PassesEnableNERFlag(t *testing.T) {
	// The handler runs on a gRPC server goroutine. A channel gives a
	// race-detector-visible handoff to the test goroutine; a shared bool would
	// only be ordered via the TCP socket, which the detector does not track.
	received := make(chan bool, 1)
	srv := &mockPiiServer{
		detectFn: func(req *piiv1.PiiRequest) (*piiv1.PiiResponse, error) {
			enableNER := false
			if req.Options != nil {
				enableNER = req.Options.EnableNer
			}
			// Non-blocking send so that an unexpected second call cannot wedge
			// the handler.
			select {
			case received <- enableNER:
			default:
			}
			return &piiv1.PiiResponse{AnonymizedText: req.Text}, nil
		},
	}
	addr := startMockServer(t, srv)
	client := newTestClient(t, addr, false)

	_, err := client.Detect(context.Background(), "text", "t", "r", true, false)
	require.NoError(t, err)

	select {
	case got := <-received:
		assert.True(t, got)
	case <-time.After(time.Second):
		t.Fatal("mock server never received the Detect request")
	}
}