veylant/internal/auditlog/ch_logger.go
2026-02-23 13:35:04 +01:00

254 lines
7.6 KiB
Go

package auditlog
import (
"context"
"fmt"
"os"
"sort"
"strings"
"time"
"github.com/ClickHouse/clickhouse-go/v2"
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
"go.uber.org/zap"
)
// ClickHouseLogger implements Logger + Flusher backed by a ClickHouse connection.
// Query/QueryCosts perform synchronous CH queries for the admin API.
// Log() is non-blocking: entries are queued in BatchWriter (not directly here).
type ClickHouseLogger struct {
	conn   driver.Conn  // native-protocol connection (pool managed by clickhouse-go)
	logger *zap.Logger  // structured logger; also handed to the BatchWriter
	bw     *BatchWriter // async queue; constructed with this logger as its flush target
}
// NewClickHouseLogger opens a ClickHouse native connection from a DSN string
// (clickhouse://user:pass@host:9000/database) and returns a ClickHouseLogger.
// maxConns and dialTimeoutSec override driver defaults only when positive.
// The caller must call Start() and defer Stop().
func NewClickHouseLogger(dsn string, maxConns, dialTimeoutSec int, logger *zap.Logger) (*ClickHouseLogger, error) {
	opts, err := clickhouse.ParseDSN(dsn)
	if err != nil {
		return nil, fmt.Errorf("clickhouse: parse DSN: %w", err)
	}
	if maxConns > 0 {
		opts.MaxOpenConns = maxConns
	}
	if dialTimeoutSec > 0 {
		opts.DialTimeout = time.Duration(dialTimeoutSec) * time.Second
	}
	conn, err := clickhouse.Open(opts)
	if err != nil {
		return nil, fmt.Errorf("clickhouse: open: %w", err)
	}
	// Verify connectivity up front with a bounded wait so a misconfigured or
	// unreachable server fails startup fast instead of hanging indefinitely
	// (DialTimeout only bounds the dial, not the ping round-trip).
	pingCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()
	if err := conn.Ping(pingCtx); err != nil {
		// Close the connection we just opened; without this the pool leaks
		// on every failed startup attempt.
		_ = conn.Close()
		return nil, fmt.Errorf("clickhouse: ping: %w", err)
	}
	ch := &ClickHouseLogger{conn: conn, logger: logger}
	ch.bw = NewBatchWriter(ch, logger)
	return ch, nil
}
// ApplyDDL reads and executes the ClickHouse DDL file at startup (idempotent).
// The file may contain multiple semicolon-separated statements and "--" line
// comments anywhere, including above a statement.
func (c *ClickHouseLogger) ApplyDDL(sqlPath string) error {
	data, err := os.ReadFile(sqlPath)
	if err != nil {
		return fmt.Errorf("clickhouse: read DDL %s: %w", sqlPath, err)
	}
	// Split on semicolons to handle multi-statement files. This simple splitter
	// does not understand semicolons inside string literals, which DDL files
	// are not expected to contain.
	for _, raw := range strings.Split(string(data), ";") {
		// Strip comment lines per statement instead of skipping any statement
		// that merely STARTS with a comment — the old HasPrefix check silently
		// dropped DDL that had a leading "--" comment line.
		stmt := stripSQLLineComments(raw)
		if stmt == "" {
			continue
		}
		if err := c.conn.Exec(context.Background(), stmt); err != nil {
			return fmt.Errorf("clickhouse: exec DDL: %w", err)
		}
	}
	return nil
}

// stripSQLLineComments removes "--" line comments and blank lines from a
// statement and returns the trimmed remainder ("" if nothing is left).
func stripSQLLineComments(stmt string) string {
	var b strings.Builder
	for _, line := range strings.Split(stmt, "\n") {
		trimmed := strings.TrimSpace(line)
		if trimmed == "" || strings.HasPrefix(trimmed, "--") {
			continue
		}
		if b.Len() > 0 {
			b.WriteByte('\n')
		}
		b.WriteString(trimmed)
	}
	return b.String()
}
// ─── Logger interface ─────────────────────────────────────────────────────────

// Log queues entry on the batch writer; per the type contract it does not
// insert synchronously.
func (c *ClickHouseLogger) Log(entry AuditEntry) {
	c.bw.Log(entry)
}

// Start starts the underlying batch writer.
func (c *ClickHouseLogger) Start() {
	c.bw.Start()
}

// Stop stops the underlying batch writer.
func (c *ClickHouseLogger) Stop() {
	c.bw.Stop()
}
// ─── Flusher interface ────────────────────────────────────────────────────────

// InsertBatch writes entries to audit_logs in one prepared batch. Integer
// fields are narrowed to the column widths used in the table (UInt32/UInt16).
// On append failure the batch is aborted so its connection is released.
func (c *ClickHouseLogger) InsertBatch(ctx context.Context, entries []AuditEntry) error {
	// Avoid preparing (and round-tripping) an empty batch.
	if len(entries) == 0 {
		return nil
	}
	batch, err := c.conn.PrepareBatch(ctx, "INSERT INTO audit_logs")
	if err != nil {
		return fmt.Errorf("clickhouse: prepare batch: %w", err)
	}
	for _, e := range entries {
		if err := batch.Append(
			e.RequestID,
			e.TenantID,
			e.UserID,
			e.Timestamp,
			e.ModelRequested,
			e.ModelUsed,
			e.Provider,
			e.Department,
			e.UserRole,
			e.PromptHash,
			e.ResponseHash,
			e.PromptAnonymized,
			e.SensitivityLevel,
			uint32(e.TokenInput),
			uint32(e.TokenOutput),
			uint32(e.TokenTotal),
			e.CostUSD,
			uint32(e.LatencyMs),
			e.Status,
			e.ErrorType,
			uint16(e.PIIEntityCount),
			e.Stream,
		); err != nil {
			// Release the connection held by the open batch; without Abort it
			// stays checked out until garbage collection.
			_ = batch.Abort()
			return fmt.Errorf("clickhouse: append row: %w", err)
		}
	}
	// Wrap the send error for consistency with every other error path here.
	if err := batch.Send(); err != nil {
		return fmt.Errorf("clickhouse: send batch: %w", err)
	}
	return nil
}
// ─── Query ────────────────────────────────────────────────────────────────────

// Query returns audit entries for the admin API, filtered per q and ordered
// newest-first. prompt_anonymized is intentionally excluded from results.
// Total is the number of rows on this page (len(Data)), not the overall
// match count.
func (c *ClickHouseLogger) Query(ctx context.Context, q AuditQuery) (*AuditResult, error) {
	limit := q.Limit
	if limit <= 0 || limit > 200 {
		limit = 50 // default page size; oversized requests fall back to it too
	}
	offset := q.Offset

	conditions := []string{"tenant_id = ?"}
	args := []interface{}{q.TenantID}
	if !q.StartTime.IsZero() {
		conditions = append(conditions, "timestamp >= ?")
		args = append(args, q.StartTime)
	}
	if !q.EndTime.IsZero() {
		conditions = append(conditions, "timestamp <= ?")
		args = append(args, q.EndTime)
	}
	if q.UserID != "" {
		conditions = append(conditions, "user_id = ?")
		args = append(args, q.UserID)
	}
	if q.Provider != "" {
		conditions = append(conditions, "provider = ?")
		args = append(args, q.Provider)
	}
	// Fixed severity ladder, lowest to highest. The interpolated IN (...) list
	// is built only from these constants, so it is injection-safe; iterating a
	// slice (instead of ranging a map) makes the generated SQL deterministic.
	// Unknown MinSensitivity values are ignored, as before.
	ladder := []string{"none", "low", "medium", "high", "critical"}
	for i, lvl := range ladder {
		if lvl == q.MinSensitivity {
			quoted := make([]string, 0, len(ladder)-i)
			for _, l := range ladder[i:] {
				quoted = append(quoted, "'"+l+"'")
			}
			conditions = append(conditions, "sensitivity_level IN ("+strings.Join(quoted, ",")+")")
			break
		}
	}
	where := strings.Join(conditions, " AND ")
	query := fmt.Sprintf(
		"SELECT request_id, tenant_id, user_id, timestamp, model_requested, model_used, provider, "+
			"department, user_role, prompt_hash, response_hash, sensitivity_level, "+
			"token_input, token_output, token_total, cost_usd, latency_ms, status, "+
			"error_type, pii_entity_count, stream FROM audit_logs WHERE %s "+
			"ORDER BY timestamp DESC LIMIT %d OFFSET %d",
		where, limit, offset,
	)
	rows, err := c.conn.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("clickhouse: query logs: %w", err)
	}
	defer rows.Close()
	var entries []AuditEntry
	for rows.Next() {
		var e AuditEntry
		var tokenIn, tokenOut, tokenTotal uint32
		var latencyMs uint32
		var piiCount uint16
		if err := rows.Scan(
			&e.RequestID, &e.TenantID, &e.UserID, &e.Timestamp,
			&e.ModelRequested, &e.ModelUsed, &e.Provider,
			&e.Department, &e.UserRole, &e.PromptHash, &e.ResponseHash,
			&e.SensitivityLevel, &tokenIn, &tokenOut, &tokenTotal,
			&e.CostUSD, &latencyMs, &e.Status, &e.ErrorType, &piiCount, &e.Stream,
		); err != nil {
			return nil, fmt.Errorf("clickhouse: scan: %w", err)
		}
		e.TokenInput = int(tokenIn)
		e.TokenOutput = int(tokenOut)
		e.TokenTotal = int(tokenTotal)
		e.LatencyMs = int(latencyMs)
		e.PIIEntityCount = int(piiCount)
		// prompt_anonymized is intentionally excluded from query results.
		entries = append(entries, e)
	}
	// Surface mid-stream failures (e.g. a dropped connection) instead of
	// silently returning a truncated page.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("clickhouse: rows: %w", err)
	}
	return &AuditResult{Data: entries, Total: len(entries)}, nil
}
// QueryCosts aggregates token usage, spend, and request counts per group for
// the admin API. q.GroupBy selects the grouping column from a fixed whitelist
// ("model" → model_used, "department" → department, anything else →
// provider), so interpolating groupField into the SQL is injection-safe.
func (c *ClickHouseLogger) QueryCosts(ctx context.Context, q CostQuery) (*CostResult, error) {
	groupField := "provider"
	switch q.GroupBy {
	case "model":
		groupField = "model_used"
	case "department":
		groupField = "department"
	}
	conditions := []string{"tenant_id = ?"}
	args := []interface{}{q.TenantID}
	if !q.StartTime.IsZero() {
		conditions = append(conditions, "timestamp >= ?")
		args = append(args, q.StartTime)
	}
	if !q.EndTime.IsZero() {
		conditions = append(conditions, "timestamp <= ?")
		args = append(args, q.EndTime)
	}
	where := strings.Join(conditions, " AND ")
	query := fmt.Sprintf(
		"SELECT %s, sum(token_total), sum(cost_usd), count() FROM audit_logs WHERE %s GROUP BY %s ORDER BY %s",
		groupField, where, groupField, groupField,
	)
	rows, err := c.conn.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("clickhouse: query costs: %w", err)
	}
	defer rows.Close()
	var data []CostSummary
	for rows.Next() {
		var s CostSummary
		var tokens, count uint64
		if err := rows.Scan(&s.Key, &tokens, &s.TotalCostUSD, &count); err != nil {
			return nil, fmt.Errorf("clickhouse: scan cost: %w", err)
		}
		s.TotalTokens = int(tokens)
		s.RequestCount = int(count)
		data = append(data, s)
	}
	// Surface mid-stream iteration failures instead of silently returning a
	// truncated aggregate set.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("clickhouse: rows: %w", err)
	}
	// Redundant with the SQL ORDER BY, but kept as a cheap local guarantee of
	// deterministic ordering for API consumers.
	sort.Slice(data, func(i, j int) bool { return data[i].Key < data[j].Key })
	return &CostResult{Data: data}, nil
}