150 lines
4.5 KiB
Go
150 lines
4.5 KiB
Go
package pii_test
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/zap"
|
|
"google.golang.org/grpc"
|
|
|
|
piiv1 "github.com/veylant/ia-gateway/gen/pii/v1"
|
|
"github.com/veylant/ia-gateway/internal/pii"
|
|
)
|
|
|
|
// ─── Mock gRPC server ────────────────────────────────────────────────────────
|
|
|
|
type mockPiiServer struct {
|
|
piiv1.UnimplementedPiiServiceServer
|
|
detectFn func(*piiv1.PiiRequest) (*piiv1.PiiResponse, error)
|
|
}
|
|
|
|
func (m *mockPiiServer) Detect(_ context.Context, req *piiv1.PiiRequest) (*piiv1.PiiResponse, error) {
|
|
if m.detectFn != nil {
|
|
return m.detectFn(req)
|
|
}
|
|
return &piiv1.PiiResponse{
|
|
AnonymizedText: req.Text,
|
|
ProcessingTimeMs: 1,
|
|
}, nil
|
|
}
|
|
|
|
func (m *mockPiiServer) Health(_ context.Context, _ *piiv1.HealthRequest) (*piiv1.HealthResponse, error) {
|
|
return &piiv1.HealthResponse{Status: "ok", NerModelLoaded: true}, nil
|
|
}
|
|
|
|
// startMockServer starts a gRPC server with the given servicer and returns its address.
|
|
func startMockServer(t *testing.T, srv piiv1.PiiServiceServer) string {
|
|
t.Helper()
|
|
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
|
require.NoError(t, err)
|
|
|
|
s := grpc.NewServer()
|
|
piiv1.RegisterPiiServiceServer(s, srv)
|
|
|
|
t.Cleanup(func() { s.Stop() })
|
|
go func() { _ = s.Serve(lis) }()
|
|
|
|
return lis.Addr().String()
|
|
}
|
|
|
|
func newTestClient(t *testing.T, addr string, failOpen bool) *pii.Client {
|
|
t.Helper()
|
|
c, err := pii.New(pii.Config{
|
|
Address: addr,
|
|
Timeout: 500 * time.Millisecond,
|
|
FailOpen: failOpen,
|
|
}, zap.NewNop())
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() { _ = c.Close() })
|
|
return c
|
|
}
|
|
|
|
// ─── Tests ───────────────────────────────────────────────────────────────────
|
|
|
|
func TestClient_Detect_Nominal(t *testing.T) {
|
|
srv := &mockPiiServer{
|
|
detectFn: func(req *piiv1.PiiRequest) (*piiv1.PiiResponse, error) {
|
|
return &piiv1.PiiResponse{
|
|
AnonymizedText: "[PII:EMAIL:abc12345]",
|
|
Entities: []*piiv1.PiiEntity{
|
|
{
|
|
EntityType: "EMAIL",
|
|
OriginalValue: "alice@example.com",
|
|
Pseudonym: "[PII:EMAIL:abc12345]",
|
|
Start: 0,
|
|
End: 17,
|
|
Confidence: 1.0,
|
|
DetectionLayer: "regex",
|
|
},
|
|
},
|
|
ProcessingTimeMs: 2,
|
|
}, nil
|
|
},
|
|
}
|
|
addr := startMockServer(t, srv)
|
|
client := newTestClient(t, addr, false)
|
|
|
|
result, err := client.Detect(context.Background(), "alice@example.com", "tenant1", "req1", false, false)
|
|
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "[PII:EMAIL:abc12345]", result.AnonymizedText)
|
|
require.Len(t, result.Entities, 1)
|
|
assert.Equal(t, "EMAIL", result.Entities[0].EntityType)
|
|
assert.Equal(t, "alice@example.com", result.Entities[0].OriginalValue)
|
|
assert.Equal(t, "[PII:EMAIL:abc12345]", result.Entities[0].Pseudonym)
|
|
assert.Equal(t, int64(2), result.ProcessingTimeMs)
|
|
}
|
|
|
|
func TestClient_Detect_NoEntities(t *testing.T) {
|
|
addr := startMockServer(t, &mockPiiServer{})
|
|
client := newTestClient(t, addr, false)
|
|
|
|
result, err := client.Detect(context.Background(), "Bonjour le monde", "t", "r", false, false)
|
|
|
|
require.NoError(t, err)
|
|
assert.Equal(t, "Bonjour le monde", result.AnonymizedText)
|
|
assert.Empty(t, result.Entities)
|
|
}
|
|
|
|
func TestClient_Detect_FailOpen_ReturnsOriginalOnError(t *testing.T) {
|
|
// Point to a port where nothing is listening
|
|
client := newTestClient(t, "127.0.0.1:19999", true)
|
|
|
|
result, err := client.Detect(context.Background(), "original text", "t", "r", false, false)
|
|
|
|
require.NoError(t, err) // fail_open: no error returned
|
|
assert.Equal(t, "original text", result.AnonymizedText)
|
|
assert.Empty(t, result.Entities)
|
|
}
|
|
|
|
func TestClient_Detect_FailClosed_ReturnsError(t *testing.T) {
|
|
client := newTestClient(t, "127.0.0.1:19998", false)
|
|
|
|
_, err := client.Detect(context.Background(), "original text", "t", "r", false, false)
|
|
|
|
require.Error(t, err)
|
|
assert.Contains(t, err.Error(), "pii:")
|
|
}
|
|
|
|
func TestClient_Detect_PassesEnableNERFlag(t *testing.T) {
|
|
var receivedEnableNER bool
|
|
srv := &mockPiiServer{
|
|
detectFn: func(req *piiv1.PiiRequest) (*piiv1.PiiResponse, error) {
|
|
if req.Options != nil {
|
|
receivedEnableNER = req.Options.EnableNer
|
|
}
|
|
return &piiv1.PiiResponse{AnonymizedText: req.Text}, nil
|
|
},
|
|
}
|
|
addr := startMockServer(t, srv)
|
|
client := newTestClient(t, addr, false)
|
|
|
|
_, err := client.Detect(context.Background(), "text", "t", "r", true, false)
|
|
require.NoError(t, err)
|
|
assert.True(t, receivedEnableNER)
|
|
}
|