veylant/internal/pii/client_test.go
2026-02-23 13:35:04 +01:00

150 lines
4.5 KiB
Go

package pii_test
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/zap"
"google.golang.org/grpc"
piiv1 "github.com/veylant/ia-gateway/gen/pii/v1"
"github.com/veylant/ia-gateway/internal/pii"
)
// ─── Mock gRPC server ────────────────────────────────────────────────────────
// mockPiiServer is an in-process PiiServiceServer used to drive pii.Client in
// tests. Unimplemented RPCs fall back to the embedded
// UnimplementedPiiServiceServer.
type mockPiiServer struct {
	piiv1.UnimplementedPiiServiceServer

	// detectFn, when non-nil, fully determines the Detect response; when nil,
	// Detect echoes the request text back.
	detectFn func(*piiv1.PiiRequest) (*piiv1.PiiResponse, error)
}
// Detect implements PiiServiceServer. It delegates to detectFn when one is
// configured; otherwise it returns the input text unchanged, no entities, and
// a ProcessingTimeMs of 1.
func (m *mockPiiServer) Detect(_ context.Context, req *piiv1.PiiRequest) (*piiv1.PiiResponse, error) {
	if fn := m.detectFn; fn != nil {
		return fn(req)
	}

	resp := &piiv1.PiiResponse{
		AnonymizedText:   req.Text,
		ProcessingTimeMs: 1,
	}
	return resp, nil
}
// Health implements PiiServiceServer and always reports status "ok" with the
// NER model loaded.
func (m *mockPiiServer) Health(_ context.Context, _ *piiv1.HealthRequest) (*piiv1.HealthResponse, error) {
	resp := &piiv1.HealthResponse{
		Status:         "ok",
		NerModelLoaded: true,
	}
	return resp, nil
}
// startMockServer serves srv over gRPC on an ephemeral loopback port and
// returns the "host:port" address it listens on. The server is stopped during
// test cleanup, which is what ends the Serve goroutine started here.
func startMockServer(t *testing.T, srv piiv1.PiiServiceServer) string {
	t.Helper()

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)

	server := grpc.NewServer()
	piiv1.RegisterPiiServiceServer(server, srv)
	t.Cleanup(server.Stop)

	go func() {
		// Serve blocks until the server is stopped or the listener fails; the
		// returned error is deliberately ignored because tests have no use for it.
		_ = server.Serve(listener)
	}()

	return listener.Addr().String()
}
// newTestClient builds a pii.Client pointed at addr with a 500 ms timeout, the
// given fail-open policy and a no-op logger. The client is closed during test
// cleanup.
func newTestClient(t *testing.T, addr string, failOpen bool) *pii.Client {
	t.Helper()

	cfg := pii.Config{
		Address:  addr,
		Timeout:  500 * time.Millisecond,
		FailOpen: failOpen,
	}
	client, err := pii.New(cfg, zap.NewNop())
	require.NoError(t, err)

	t.Cleanup(func() {
		// A Close error is irrelevant once the test body has finished.
		_ = client.Close()
	})
	return client
}
// ─── Tests ───────────────────────────────────────────────────────────────────
// TestClient_Detect_Nominal checks that a successful Detect call forwards the
// input text to the service and maps the anonymized text, entities and
// processing time from the service response into the result.
func TestClient_Detect_Nominal(t *testing.T) {
	// Records the text the server received. Buffered plus a non-blocking send,
	// so an unexpected extra call can never wedge the handler goroutine.
	gotText := make(chan string, 1)
	srv := &mockPiiServer{
		detectFn: func(req *piiv1.PiiRequest) (*piiv1.PiiResponse, error) {
			select {
			case gotText <- req.Text:
			default:
			}
			return &piiv1.PiiResponse{
				AnonymizedText: "[PII:EMAIL:abc12345]",
				Entities: []*piiv1.PiiEntity{
					{
						EntityType:     "EMAIL",
						OriginalValue:  "alice@example.com",
						Pseudonym:      "[PII:EMAIL:abc12345]",
						Start:          0,
						End:            17,
						Confidence:     1.0,
						DetectionLayer: "regex",
					},
				},
				ProcessingTimeMs: 2,
			}, nil
		},
	}
	addr := startMockServer(t, srv)
	client := newTestClient(t, addr, false)

	result, err := client.Detect(context.Background(), "alice@example.com", "tenant1", "req1", false, false)
	require.NoError(t, err)

	// The handler has returned by the time Detect succeeds, so the value is
	// already buffered; an empty channel means the server was never called.
	select {
	case text := <-gotText:
		assert.Equal(t, "alice@example.com", text, "client must forward the original text")
	default:
		t.Fatal("mock server did not receive a Detect request")
	}

	assert.Equal(t, "[PII:EMAIL:abc12345]", result.AnonymizedText)
	require.Len(t, result.Entities, 1)
	assert.Equal(t, "EMAIL", result.Entities[0].EntityType)
	assert.Equal(t, "alice@example.com", result.Entities[0].OriginalValue)
	assert.Equal(t, "[PII:EMAIL:abc12345]", result.Entities[0].Pseudonym)
	assert.Equal(t, int64(2), result.ProcessingTimeMs)
}
func TestClient_Detect_NoEntities(t *testing.T) {
addr := startMockServer(t, &mockPiiServer{})
client := newTestClient(t, addr, false)
result, err := client.Detect(context.Background(), "Bonjour le monde", "t", "r", false, false)
require.NoError(t, err)
assert.Equal(t, "Bonjour le monde", result.AnonymizedText)
assert.Empty(t, result.Entities)
}
// unusedAddr returns a loopback "host:port" whose port was free a moment ago
// and has nothing listening on it now, so dialing it should be refused. Letting
// the OS pick the port avoids colliding with whatever might be bound to a
// hard-coded port on the test machine.
func unusedAddr(t *testing.T) string {
	t.Helper()
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	require.NoError(t, err)
	addr := lis.Addr().String()
	require.NoError(t, lis.Close())
	return addr
}

// TestClient_Detect_FailOpen_ReturnsOriginalOnError checks that, with FailOpen
// enabled, an unreachable service yields the original text, no entities and
// no error.
func TestClient_Detect_FailOpen_ReturnsOriginalOnError(t *testing.T) {
	client := newTestClient(t, unusedAddr(t), true)
	result, err := client.Detect(context.Background(), "original text", "t", "r", false, false)
	require.NoError(t, err) // fail-open: the transport error is not returned
	assert.Equal(t, "original text", result.AnonymizedText)
	assert.Empty(t, result.Entities)
}

// TestClient_Detect_FailClosed_ReturnsError checks that, with FailOpen
// disabled, an unreachable service surfaces an error mentioning "pii:".
func TestClient_Detect_FailClosed_ReturnsError(t *testing.T) {
	client := newTestClient(t, unusedAddr(t), false)
	_, err := client.Detect(context.Background(), "original text", "t", "r", false, false)
	require.Error(t, err)
	assert.Contains(t, err.Error(), "pii:")
}
func TestClient_Detect_PassesEnableNERFlag(t *testing.T) {
var receivedEnableNER bool
srv := &mockPiiServer{
detectFn: func(req *piiv1.PiiRequest) (*piiv1.PiiResponse, error) {
if req.Options != nil {
receivedEnableNER = req.Options.EnableNer
}
return &piiv1.PiiResponse{AnonymizedText: req.Text}, nil
},
}
addr := startMockServer(t, srv)
client := newTestClient(t, addr, false)
_, err := client.Detect(context.Background(), "text", "t", "r", true, false)
require.NoError(t, err)
assert.True(t, receivedEnableNER)
}