"""Pseudonymization and de-pseudonymization of PII entities. Tokens: [PII::<8-hex-chars>] Mappings stored in Redis encrypted with AES-256-GCM. """ from __future__ import annotations import base64 import os import re import uuid from typing import Optional from cryptography.hazmat.primitives.ciphers.aead import AESGCM from layers.regex_layer import DetectedEntity # --------------------------------------------------------------------------- # AES-256-GCM encryptor # --------------------------------------------------------------------------- _TOKEN_RE = re.compile(r"\[PII:[A-Z_]+:[0-9a-f]{8}\]") class AESEncryptor: """AES-256-GCM encrypt/decrypt using a 32-byte key.""" _NONCE_SIZE = 12 # bytes def __init__(self, key_b64: str) -> None: key = base64.b64decode(key_b64) if len(key) != 32: raise ValueError(f"Encryption key must be 32 bytes, got {len(key)}") self._aesgcm = AESGCM(key) def encrypt(self, plaintext: str) -> bytes: """Return nonce (12 bytes) + ciphertext.""" nonce = os.urandom(self._NONCE_SIZE) ciphertext = self._aesgcm.encrypt(nonce, plaintext.encode(), None) return nonce + ciphertext def decrypt(self, data: bytes) -> str: """Decrypt nonce-prefixed ciphertext and return plaintext string.""" nonce, ciphertext = data[: self._NONCE_SIZE], data[self._NONCE_SIZE :] return self._aesgcm.decrypt(nonce, ciphertext, None).decode() # --------------------------------------------------------------------------- # PseudonymMapper # --------------------------------------------------------------------------- class PseudonymMapper: """Map PII entities to pseudonym tokens, persisted in Redis.""" def __init__(self, redis_client, encryptor: AESEncryptor, ttl_seconds: int = 3600) -> None: self._redis = redis_client self._encryptor = encryptor self._ttl = ttl_seconds # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def anonymize( self, text: str, entities: list[DetectedEntity], tenant_id: str, request_id: str, ) -> tuple[str, dict[str, str]]: """Replace PII entities in *text* with pseudonym tokens. Returns: (anonymized_text, mapping) where mapping is {token → original_value}. """ mapping: dict[str, str] = {} # Process entities from end to start to preserve offsets sorted_entities = sorted(entities, key=lambda e: e.start, reverse=True) result = text for entity in sorted_entities: short_id = uuid.uuid4().hex[:8] token = f"[PII:{entity.entity_type}:{short_id}]" mapping[token] = entity.original_value # Replace in text (slice is safe because we process right-to-left) result = result[: entity.start] + token + result[entity.end :] # Store encrypted mapping in Redis redis_key = f"pii:{tenant_id}:{request_id}:{short_id}" encrypted = self._encryptor.encrypt(entity.original_value) self._redis.set(redis_key, encrypted, ex=self._ttl) return result, mapping def depseudonymize(self, text: str, mapping: dict[str, str]) -> str: """Replace pseudonym tokens in *text* with their original values.""" def replace(m: re.Match) -> str: token = m.group(0) return mapping.get(token, token) return _TOKEN_RE.sub(replace, text) def load_mapping_from_redis( self, tenant_id: str, request_id: str ) -> dict[str, str]: """Reconstruct the token → original mapping from Redis for a given request.""" pattern = f"pii:{tenant_id}:{request_id}:*" mapping: dict[str, str] = {} for key in self._redis.scan_iter(pattern): encrypted = self._redis.get(key) if encrypted is None: continue short_id = key.decode().split(":")[-1] if isinstance(key, bytes) else key.split(":")[-1] original = self._encryptor.decrypt(encrypted) # We don't know the entity_type from the key alone — token must be in text # This method is for reference; in practice the mapping is passed in-memory mapping[short_id] = original return mapping