package auditlog

import (
	"context"
	"fmt"
	"os"
	"sort"
	"strings"
	"time"

	"github.com/ClickHouse/clickhouse-go/v2"
	"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
	"go.uber.org/zap"
)

// ClickHouseLogger implements Logger + Flusher backed by a ClickHouse connection.
// Query/QueryCosts perform synchronous CH queries for the admin API.
// Log() is non-blocking: entries are queued in BatchWriter (not directly here).
type ClickHouseLogger struct {
	conn   driver.Conn
	logger *zap.Logger
	bw     *BatchWriter
}

// NewClickHouseLogger opens a ClickHouse native connection from a DSN string
// (clickhouse://user:pass@host:9000/database) and returns a ClickHouseLogger.
// maxConns and dialTimeoutSec override the DSN defaults when > 0.
// The caller must call Start() and defer Stop().
func NewClickHouseLogger(dsn string, maxConns, dialTimeoutSec int, logger *zap.Logger) (*ClickHouseLogger, error) {
	opts, err := clickhouse.ParseDSN(dsn)
	if err != nil {
		return nil, fmt.Errorf("clickhouse: parse DSN: %w", err)
	}
	if maxConns > 0 {
		opts.MaxOpenConns = maxConns
	}
	if dialTimeoutSec > 0 {
		opts.DialTimeout = time.Duration(dialTimeoutSec) * time.Second
	}
	conn, err := clickhouse.Open(opts)
	if err != nil {
		return nil, fmt.Errorf("clickhouse: open: %w", err)
	}
	if err := conn.Ping(context.Background()); err != nil {
		// Close the pool on ping failure so we don't leak sockets; the close
		// error is secondary to the ping error we report.
		_ = conn.Close()
		return nil, fmt.Errorf("clickhouse: ping: %w", err)
	}
	ch := &ClickHouseLogger{conn: conn, logger: logger}
	ch.bw = NewBatchWriter(ch, logger)
	return ch, nil
}

// ApplyDDL reads and executes the ClickHouse DDL file at startup (idempotent).
// Statements are split naively on ';', so the DDL file must not contain
// semicolons inside string literals. Statements whose first non-blank
// characters are "--" are skipped wholesale; a statement that merely starts
// with a comment line followed by SQL would be skipped too — keep comments on
// their own statements in the DDL file.
func (c *ClickHouseLogger) ApplyDDL(sqlPath string) error {
	data, err := os.ReadFile(sqlPath)
	if err != nil {
		return fmt.Errorf("clickhouse: read DDL %s: %w", sqlPath, err)
	}
	// Split on semicolons to handle multi-statement files.
	for _, stmt := range strings.Split(string(data), ";") {
		stmt = strings.TrimSpace(stmt)
		if stmt == "" || strings.HasPrefix(stmt, "--") {
			continue
		}
		if err := c.conn.Exec(context.Background(), stmt); err != nil {
			return fmt.Errorf("clickhouse: exec DDL: %w", err)
		}
	}
	return nil
}

// ─── Logger interface ─────────────────────────────────────────────────────────

// Log enqueues an entry in the BatchWriter; it never blocks on ClickHouse.
func (c *ClickHouseLogger) Log(entry AuditEntry) { c.bw.Log(entry) }

// Start launches the background batch-flush loop.
func (c *ClickHouseLogger) Start() { c.bw.Start() }

// Stop drains and stops the background batch-flush loop.
func (c *ClickHouseLogger) Stop() { c.bw.Stop() }

// ─── Flusher interface ────────────────────────────────────────────────────────

// InsertBatch writes a slice of entries to audit_logs in a single native-protocol
// batch. Column order here must match the audit_logs table definition.
func (c *ClickHouseLogger) InsertBatch(ctx context.Context, entries []AuditEntry) error {
	batch, err := c.conn.PrepareBatch(ctx, "INSERT INTO audit_logs")
	if err != nil {
		return fmt.Errorf("clickhouse: prepare batch: %w", err)
	}
	for _, e := range entries {
		if err := batch.Append(
			e.RequestID, e.TenantID, e.UserID, e.Timestamp,
			e.ModelRequested, e.ModelUsed, e.Provider,
			e.Department, e.UserRole,
			e.PromptHash, e.ResponseHash, e.PromptAnonymized, e.SensitivityLevel,
			// Token/latency/PII counts are stored unsigned; the int fields are
			// expected to be non-negative upstream.
			uint32(e.TokenInput), uint32(e.TokenOutput), uint32(e.TokenTotal),
			e.CostUSD, uint32(e.LatencyMs),
			e.Status, e.ErrorType, uint16(e.PIIEntityCount), e.Stream,
		); err != nil {
			return fmt.Errorf("clickhouse: append row: %w", err)
		}
	}
	return batch.Send()
}

// ─── Query ────────────────────────────────────────────────────────────────────

// Query returns a page of audit log entries matching q, newest first.
// Limit is clamped to [1, 200] (default 50). prompt_anonymized is intentionally
// excluded from query results.
func (c *ClickHouseLogger) Query(ctx context.Context, q AuditQuery) (*AuditResult, error) {
	limit := q.Limit
	if limit <= 0 || limit > 200 {
		limit = 50
	}
	offset := q.Offset

	var conditions []string
	var args []any
	conditions = append(conditions, "tenant_id = ?")
	args = append(args, q.TenantID)
	if !q.StartTime.IsZero() {
		conditions = append(conditions, "timestamp >= ?")
		args = append(args, q.StartTime)
	}
	if !q.EndTime.IsZero() {
		conditions = append(conditions, "timestamp <= ?")
		args = append(args, q.EndTime)
	}
	if q.UserID != "" {
		conditions = append(conditions, "user_id = ?")
		args = append(args, q.UserID)
	}
	if q.Provider != "" {
		conditions = append(conditions, "provider = ?")
		args = append(args, q.Provider)
	}

	// Sensitivity filter: include every level at or above the requested minimum.
	// Values are interpolated directly (not bound) but come only from the fixed
	// key set below, so no injection is possible. The list is sorted because map
	// iteration order is random and we want a stable query text.
	sensitivityOrder := map[string]int{"none": 0, "low": 1, "medium": 2, "high": 3, "critical": 4}
	if minLvl, ok := sensitivityOrder[q.MinSensitivity]; ok {
		levels := make([]string, 0, len(sensitivityOrder))
		for lvl, ord := range sensitivityOrder {
			if ord >= minLvl {
				levels = append(levels, "'"+lvl+"'")
			}
		}
		sort.Strings(levels)
		conditions = append(conditions, "sensitivity_level IN ("+strings.Join(levels, ",")+")")
	}

	where := strings.Join(conditions, " AND ")
	query := fmt.Sprintf(
		"SELECT request_id, tenant_id, user_id, timestamp, model_requested, model_used, provider, "+
			"department, user_role, prompt_hash, response_hash, sensitivity_level, "+
			"token_input, token_output, token_total, cost_usd, latency_ms, status, "+
			"error_type, pii_entity_count, stream FROM audit_logs WHERE %s "+
			"ORDER BY timestamp DESC LIMIT %d OFFSET %d",
		where, limit, offset,
	)
	rows, err := c.conn.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("clickhouse: query logs: %w", err)
	}
	defer rows.Close()

	var entries []AuditEntry
	for rows.Next() {
		var e AuditEntry
		// Unsigned DB columns are scanned into temporaries, then widened into
		// the int fields of AuditEntry.
		var tokenIn, tokenOut, tokenTotal uint32
		var latencyMs uint32
		var piiCount uint16
		if err := rows.Scan(
			&e.RequestID, &e.TenantID, &e.UserID, &e.Timestamp,
			&e.ModelRequested, &e.ModelUsed, &e.Provider,
			&e.Department, &e.UserRole,
			&e.PromptHash, &e.ResponseHash, &e.SensitivityLevel,
			&tokenIn, &tokenOut, &tokenTotal,
			&e.CostUSD, &latencyMs,
			&e.Status, &e.ErrorType, &piiCount, &e.Stream,
		); err != nil {
			return nil, fmt.Errorf("clickhouse: scan: %w", err)
		}
		e.TokenInput = int(tokenIn)
		e.TokenOutput = int(tokenOut)
		e.TokenTotal = int(tokenTotal)
		e.LatencyMs = int(latencyMs)
		e.PIIEntityCount = int(piiCount)
		// prompt_anonymized is intentionally excluded from query results.
		entries = append(entries, e)
	}
	// Surface errors that occurred during iteration (e.g. a dropped connection
	// mid-stream) instead of returning a silently truncated page.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("clickhouse: iterate logs: %w", err)
	}
	// NOTE(review): Total reflects the size of this page, not the full match
	// count across all pages — confirm that is the intended API semantics.
	return &AuditResult{Data: entries, Total: len(entries)}, nil
}

// QueryCosts aggregates token usage, cost, and request counts for a tenant,
// grouped by provider (default), model, or department, ordered by group key.
func (c *ClickHouseLogger) QueryCosts(ctx context.Context, q CostQuery) (*CostResult, error) {
	// Map the public GroupBy value onto a whitelisted column name; anything
	// unrecognized falls back to "provider", so no user input reaches the SQL.
	groupField := "provider"
	switch q.GroupBy {
	case "model":
		groupField = "model_used"
	case "department":
		groupField = "department"
	}

	var conditions []string
	var args []any
	conditions = append(conditions, "tenant_id = ?")
	args = append(args, q.TenantID)
	if !q.StartTime.IsZero() {
		conditions = append(conditions, "timestamp >= ?")
		args = append(args, q.StartTime)
	}
	if !q.EndTime.IsZero() {
		conditions = append(conditions, "timestamp <= ?")
		args = append(args, q.EndTime)
	}
	where := strings.Join(conditions, " AND ")

	query := fmt.Sprintf(
		"SELECT %s, sum(token_total), sum(cost_usd), count() FROM audit_logs WHERE %s GROUP BY %s ORDER BY %s",
		groupField, where, groupField, groupField,
	)
	rows, err := c.conn.Query(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("clickhouse: query costs: %w", err)
	}
	defer rows.Close()

	var data []CostSummary
	for rows.Next() {
		var s CostSummary
		var tokens uint64
		var count uint64
		if err := rows.Scan(&s.Key, &tokens, &s.TotalCostUSD, &count); err != nil {
			return nil, fmt.Errorf("clickhouse: scan cost: %w", err)
		}
		s.TotalTokens = int(tokens)
		s.RequestCount = int(count)
		data = append(data, s)
	}
	// Surface iteration errors rather than returning partial aggregates.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("clickhouse: iterate costs: %w", err)
	}
	// Defensive re-sort: the SQL already orders by the group key, but keeping
	// this guarantees deterministic output even if the query changes.
	sort.Slice(data, func(i, j int) bool { return data[i].Key < data[j].Key })
	return &CostResult{Data: data}, nil
}