veylant/services/pii/main.py
2026-02-23 13:35:04 +01:00

242 lines
8.0 KiB
Python

"""Veylant IA — PII Detection Service (Sprint 3)
Implements the full 3-layer PII pipeline:
Layer 1: Regex (IBAN, email, phone, SSN, credit cards)
Layer 2: Presidio + spaCy NER (PERSON, LOCATION, ORGANIZATION)
Pseudonymization: [PII:TYPE:UUID] tokens stored in Redis (AES-256-GCM)
"""
from __future__ import annotations
import asyncio
import logging
import time
from concurrent import futures
from contextlib import asynccontextmanager
import grpc
import redis as redis_module
from fastapi import FastAPI
import config as cfg
from gen.pii.v1 import pii_pb2, pii_pb2_grpc
from pipeline import Pipeline
from pseudonymize import AESEncryptor, PseudonymMapper
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Shared singletons (initialized at startup)
# ---------------------------------------------------------------------------
# Detection pipeline shared by the HTTP and gRPC handlers. Stays None until
# `lifespan` assigns it at startup; read it through `_get_pipeline()`.
_pipeline: Pipeline | None = None
# Pseudonym mapper (built from a Redis client and an AES encryptor in
# `lifespan`). Stays None until startup; read it through `_get_mapper()`.
_pseudo_mapper: PseudonymMapper | None = None
def _get_pipeline() -> Pipeline:
    """Return the shared detection pipeline.

    Raises:
        RuntimeError: If the module-level pipeline has not been assigned yet
            (it is created by the FastAPI ``lifespan`` startup hook).
    """
    # An explicit raise (instead of ``assert``) keeps this guard active when
    # Python runs with -O, which strips assert statements and would otherwise
    # let a None pipeline leak to the callers.
    if _pipeline is None:
        raise RuntimeError("Pipeline not initialized")
    return _pipeline
def _get_mapper() -> PseudonymMapper:
    """Return the shared pseudonym mapper.

    Raises:
        RuntimeError: If the module-level mapper has not been assigned yet
            (it is created by the FastAPI ``lifespan`` startup hook).
    """
    # An explicit raise (instead of ``assert``) keeps this guard active when
    # Python runs with -O, which strips assert statements and would otherwise
    # let a None mapper leak to the callers.
    if _pseudo_mapper is None:
        raise RuntimeError("PseudonymMapper not initialized")
    return _pseudo_mapper
# ---------------------------------------------------------------------------
# FastAPI lifespan — warm up NER model before accepting traffic
# ---------------------------------------------------------------------------
@asynccontextmanager
async def lifespan(application: FastAPI):
    """Initialise the shared singletons, warm up NER, and clean up on exit.

    Startup builds the detection pipeline and the Redis-backed pseudonym
    mapper, then runs the pipeline warm-up in the default executor so the
    event loop is not blocked. A warm-up failure is logged and tolerated.
    Shutdown closes the Redis client created here.

    Args:
        application: The FastAPI application (unused; required by the
            lifespan protocol).
    """
    global _pipeline, _pseudo_mapper

    _pipeline = Pipeline()
    encryptor = AESEncryptor(cfg.ENCRYPTION_KEY_B64)
    redis_client = redis_module.from_url(cfg.REDIS_URL, decode_responses=False)
    _pseudo_mapper = PseudonymMapper(redis_client, encryptor, cfg.DEFAULT_TTL)

    logger.info("Warming up NER model (this may take a few seconds)…")
    try:
        # We are inside a coroutine, so the running loop is the right handle;
        # get_event_loop() is the legacy spelling for this use.
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _pipeline.warm_up)
        logger.info("NER model ready")
    except Exception:
        # Deliberate best-effort: the service still starts without NER.
        logger.exception("NER warm-up failed — service will use regex only")

    try:
        yield
    finally:
        logger.info("PII service shutting down")
        # Release the Redis connections opened at startup; a failure here must
        # not mask whatever caused the shutdown.
        try:
            redis_client.close()
        except Exception:
            logger.exception("Failed to close Redis client")
# ---------------------------------------------------------------------------
# FastAPI HTTP app
# ---------------------------------------------------------------------------
app = FastAPI(
    title="Veylant PII Service",
    version="0.3.0",
    lifespan=lifespan,
)


@app.get("/healthz")
async def healthz() -> dict:
    """Report service liveness over HTTP.

    Returns:
        A dict with a fixed ``"ok"`` status, whether the NER layer reports
        itself as loaded, the configured spaCy model name, and the current
        UTC time formatted as ``YYYY-MM-DDTHH:MM:SSZ``.
    """
    payload: dict = {"status": "ok"}
    payload["ner_model_loaded"] = _get_pipeline().ner_layer.is_loaded
    payload["spacy_model"] = cfg.SPACY_FR_MODEL
    payload["timestamp"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    return payload
# ---------------------------------------------------------------------------
# gRPC servicer
# ---------------------------------------------------------------------------
class PiiServiceServicer(pii_pb2_grpc.PiiServiceServicer):
    """Full gRPC implementation of PiiService."""

    def Detect(self, request, context):  # noqa: N802
        """Detect PII in ``request.text`` and return the anonymized text.

        Args:
            request: Incoming request; ``text``, ``tenant_id``, ``request_id``
                and the optional ``options`` sub-message are read.
            context: gRPC servicer context, used to abort with INTERNAL when
                detection or anonymization raises.

        Returns:
            A ``PiiResponse`` carrying the anonymized text, one ``PiiEntity``
            per detected entity, and the processing time in milliseconds
            (measured up to the end of anonymization).
        """
        start_ns = time.monotonic_ns()
        pipeline = _get_pipeline()
        mapper = _get_mapper()

        # Presence must be tested with HasField: a protobuf sub-message is not
        # a reliable truthiness check, so ``if request.options`` would never
        # reach the "NER on by default" branch when the caller omits options.
        # NOTE(review): assumes `options` is a singular message field —
        # confirm against the .proto definition.
        has_options = request.HasField("options")
        enable_ner = request.options.enable_ner if has_options else True
        confidence = (
            request.options.confidence_threshold
            if has_options and request.options.confidence_threshold > 0
            else cfg.NER_CONFIDENCE
        )

        try:
            entities = pipeline.detect(
                request.text,
                enable_ner=enable_ner,
                confidence_threshold=confidence,
            )
            anonymized_text, mapping = mapper.anonymize(
                request.text,
                entities,
                tenant_id=request.tenant_id or "default",
                request_id=request.request_id or "unknown",
            )
        except Exception as exc:
            logger.exception("PII detection failed")
            # NOTE(review): str(exc) is returned to the client — confirm that
            # exception messages from the pipeline cannot contain raw PII.
            context.abort(grpc.StatusCode.INTERNAL, str(exc))
            return None

        elapsed_ms = (time.monotonic_ns() - start_ns) // 1_000_000

        # The entity list is built once by the shared helper; the previous
        # inline comprehension was dead code whose result was overwritten.
        return pii_pb2.PiiResponse(
            anonymized_text=anonymized_text,
            entities=_build_proto_entities(entities, mapping),
            processing_time_ms=elapsed_ms,
        )

    def Health(self, request, context):  # noqa: N802
        """Report service health over gRPC.

        Returns:
            A ``HealthResponse`` with status ``"ok"``, the NER layer's loaded
            flag, and the configured spaCy model name (empty string while the
            NER layer is not loaded).
        """
        ner_loaded = _get_pipeline().ner_layer.is_loaded
        return pii_pb2.HealthResponse(
            status="ok",
            ner_model_loaded=ner_loaded,
            spacy_model=cfg.SPACY_FR_MODEL if ner_loaded else "",
        )
def _build_proto_entities(entities, mapping: dict[str, str]) -> list[pii_pb2.PiiEntity]:
    """Convert detected entities into proto messages, attaching pseudonyms.

    Args:
        entities: Detected entities, in the order they should be emitted.
        mapping: Pseudonym token -> original value, in insertion order.

    Returns:
        One ``PiiEntity`` per input entity. Entities sharing an original value
        consume that value's tokens in mapping order; once the tokens run out
        (or when the value has none) the pseudonym is the empty string.
    """
    # Group tokens by the original value they stand for, keeping mapping order.
    tokens_by_value: dict[str, list[str]] = {}
    for token, original in mapping.items():
        if original not in tokens_by_value:
            tokens_by_value[original] = []
        tokens_by_value[original].append(token)

    # One iterator per value: each occurrence pulls the next unused token.
    pending = {value: iter(tokens) for value, tokens in tokens_by_value.items()}

    def _next_token(value: str) -> str:
        queue = pending.get(value)
        if queue is None:
            return ""
        return next(queue, "")

    return [
        pii_pb2.PiiEntity(
            entity_type=entity.entity_type,
            original_value=entity.original_value,
            pseudonym=_next_token(entity.original_value),
            start=entity.start,
            end=entity.end,
            confidence=entity.confidence,
            detection_layer=entity.detection_layer,
        )
        for entity in entities
    ]
# ---------------------------------------------------------------------------
# gRPC async server
# ---------------------------------------------------------------------------
async def serve_grpc() -> None:
    """Start the gRPC server and serve until it terminates or is cancelled.

    The server listens on all interfaces at ``cfg.GRPC_PORT`` without TLS and
    accepts/sends messages up to 10 MiB. When this coroutine is cancelled (or
    the wait ends for any other reason) the server is stopped with a short
    grace period so in-flight RPCs can finish.
    """
    max_message_bytes = 10 * 1024 * 1024
    shutdown_grace_seconds = 5

    server = grpc.aio.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=[
            ("grpc.max_send_message_length", max_message_bytes),
            ("grpc.max_receive_message_length", max_message_bytes),
        ],
    )
    pii_pb2_grpc.add_PiiServiceServicer_to_server(PiiServiceServicer(), server)
    listen_addr = f"[::]:{cfg.GRPC_PORT}"
    server.add_insecure_port(listen_addr)
    await server.start()
    logger.info("gRPC server listening on %s", listen_addr)
    try:
        await server.wait_for_termination()
    finally:
        # Reached on task cancellation as well as normal termination; without
        # this the listening server is simply abandoned at shutdown.
        await server.stop(shutdown_grace_seconds)
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    async def main() -> None:
        """Run the gRPC and HTTP servers together until either one stops.

        Both servers run as tasks on the same event loop. As soon as one of
        them finishes (cleanly or with an error) the other is cancelled, so
        the process cannot hang on a server that never terminates by itself.

        Raises:
            Exception: The first error raised by either server task.
        """
        grpc_task = asyncio.create_task(serve_grpc())
        uvicorn_cfg = uvicorn.Config(
            app, host="0.0.0.0", port=cfg.HTTP_PORT, log_level="info"
        )
        uvicorn_server = uvicorn.Server(uvicorn_cfg)
        http_task = asyncio.create_task(uvicorn_server.serve())
        logger.info(
            "PII service starting (HTTP :%d, gRPC :%d)", cfg.HTTP_PORT, cfg.GRPC_PORT
        )

        tasks = (grpc_task, http_task)
        try:
            await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        finally:
            # Whatever ended the wait, do not leave the sibling server running.
            for task in tasks:
                if not task.done():
                    task.cancel()
            outcomes = await asyncio.gather(*tasks, return_exceptions=True)

        # Surface a real failure; CancelledError from the task we cancelled is
        # a BaseException and is intentionally skipped here.
        for outcome in outcomes:
            if isinstance(outcome, Exception):
                raise outcome

    asyncio.run(main())