// auditlog ClickHouse sink — original listing metadata: 254 lines, 7.6 KiB, Go.
package auditlog
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"os"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/ClickHouse/clickhouse-go/v2"
|
|
"github.com/ClickHouse/clickhouse-go/v2/lib/driver"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// ClickHouseLogger implements Logger + Flusher backed by a ClickHouse connection.
// Query/QueryCosts perform synchronous CH queries for the admin API.
// Log() is non-blocking: entries are queued in BatchWriter (not directly here).
type ClickHouseLogger struct {
	conn   driver.Conn  // native ClickHouse connection pool, shared by inserts and queries
	logger *zap.Logger  // structured logger for operational messages
	bw     *BatchWriter // async queue that calls back into InsertBatch (Flusher)
}
|
|
|
|
// NewClickHouseLogger opens a ClickHouse native connection from a DSN string
|
|
// (clickhouse://user:pass@host:9000/database) and returns a ClickHouseLogger.
|
|
// The caller must call Start() and defer Stop().
|
|
func NewClickHouseLogger(dsn string, maxConns, dialTimeoutSec int, logger *zap.Logger) (*ClickHouseLogger, error) {
|
|
opts, err := clickhouse.ParseDSN(dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("clickhouse: parse DSN: %w", err)
|
|
}
|
|
if maxConns > 0 {
|
|
opts.MaxOpenConns = maxConns
|
|
}
|
|
if dialTimeoutSec > 0 {
|
|
opts.DialTimeout = time.Duration(dialTimeoutSec) * time.Second
|
|
}
|
|
|
|
conn, err := clickhouse.Open(opts)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("clickhouse: open: %w", err)
|
|
}
|
|
if err := conn.Ping(context.Background()); err != nil {
|
|
return nil, fmt.Errorf("clickhouse: ping: %w", err)
|
|
}
|
|
|
|
ch := &ClickHouseLogger{conn: conn, logger: logger}
|
|
ch.bw = NewBatchWriter(ch, logger)
|
|
return ch, nil
|
|
}
|
|
|
|
// ApplyDDL reads and executes the ClickHouse DDL file at startup (idempotent).
|
|
func (c *ClickHouseLogger) ApplyDDL(sqlPath string) error {
|
|
data, err := os.ReadFile(sqlPath)
|
|
if err != nil {
|
|
return fmt.Errorf("clickhouse: read DDL %s: %w", sqlPath, err)
|
|
}
|
|
// Split on semicolons to handle multi-statement files.
|
|
for _, stmt := range strings.Split(string(data), ";") {
|
|
stmt = strings.TrimSpace(stmt)
|
|
if stmt == "" || strings.HasPrefix(stmt, "--") {
|
|
continue
|
|
}
|
|
if err := c.conn.Exec(context.Background(), stmt); err != nil {
|
|
return fmt.Errorf("clickhouse: exec DDL: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// ─── Logger interface ─────────────────────────────────────────────────────────
|
|
|
|
func (c *ClickHouseLogger) Log(entry AuditEntry) { c.bw.Log(entry) }
|
|
func (c *ClickHouseLogger) Start() { c.bw.Start() }
|
|
func (c *ClickHouseLogger) Stop() { c.bw.Stop() }
|
|
|
|
// ─── Flusher interface ────────────────────────────────────────────────────────
|
|
|
|
func (c *ClickHouseLogger) InsertBatch(ctx context.Context, entries []AuditEntry) error {
|
|
batch, err := c.conn.PrepareBatch(ctx, "INSERT INTO audit_logs")
|
|
if err != nil {
|
|
return fmt.Errorf("clickhouse: prepare batch: %w", err)
|
|
}
|
|
for _, e := range entries {
|
|
if err := batch.Append(
|
|
e.RequestID,
|
|
e.TenantID,
|
|
e.UserID,
|
|
e.Timestamp,
|
|
e.ModelRequested,
|
|
e.ModelUsed,
|
|
e.Provider,
|
|
e.Department,
|
|
e.UserRole,
|
|
e.PromptHash,
|
|
e.ResponseHash,
|
|
e.PromptAnonymized,
|
|
e.SensitivityLevel,
|
|
uint32(e.TokenInput),
|
|
uint32(e.TokenOutput),
|
|
uint32(e.TokenTotal),
|
|
e.CostUSD,
|
|
uint32(e.LatencyMs),
|
|
e.Status,
|
|
e.ErrorType,
|
|
uint16(e.PIIEntityCount),
|
|
e.Stream,
|
|
); err != nil {
|
|
return fmt.Errorf("clickhouse: append row: %w", err)
|
|
}
|
|
}
|
|
return batch.Send()
|
|
}
|
|
|
|
// ─── Query ────────────────────────────────────────────────────────────────────
|
|
|
|
func (c *ClickHouseLogger) Query(ctx context.Context, q AuditQuery) (*AuditResult, error) {
|
|
limit := q.Limit
|
|
if limit <= 0 || limit > 200 {
|
|
limit = 50
|
|
}
|
|
offset := q.Offset
|
|
|
|
var conditions []string
|
|
var args []interface{}
|
|
|
|
conditions = append(conditions, "tenant_id = ?")
|
|
args = append(args, q.TenantID)
|
|
|
|
if !q.StartTime.IsZero() {
|
|
conditions = append(conditions, "timestamp >= ?")
|
|
args = append(args, q.StartTime)
|
|
}
|
|
if !q.EndTime.IsZero() {
|
|
conditions = append(conditions, "timestamp <= ?")
|
|
args = append(args, q.EndTime)
|
|
}
|
|
if q.UserID != "" {
|
|
conditions = append(conditions, "user_id = ?")
|
|
args = append(args, q.UserID)
|
|
}
|
|
if q.Provider != "" {
|
|
conditions = append(conditions, "provider = ?")
|
|
args = append(args, q.Provider)
|
|
}
|
|
|
|
sensitivityOrder := map[string]int{"none": 0, "low": 1, "medium": 2, "high": 3, "critical": 4}
|
|
if _, ok := sensitivityOrder[q.MinSensitivity]; ok && q.MinSensitivity != "" {
|
|
levels := []string{}
|
|
minLvl := sensitivityOrder[q.MinSensitivity]
|
|
for lvl, ord := range sensitivityOrder {
|
|
if ord >= minLvl {
|
|
levels = append(levels, "'"+lvl+"'")
|
|
}
|
|
}
|
|
conditions = append(conditions, "sensitivity_level IN ("+strings.Join(levels, ",")+")")
|
|
}
|
|
|
|
where := strings.Join(conditions, " AND ")
|
|
query := fmt.Sprintf(
|
|
"SELECT request_id, tenant_id, user_id, timestamp, model_requested, model_used, provider, "+
|
|
"department, user_role, prompt_hash, response_hash, sensitivity_level, "+
|
|
"token_input, token_output, token_total, cost_usd, latency_ms, status, "+
|
|
"error_type, pii_entity_count, stream FROM audit_logs WHERE %s "+
|
|
"ORDER BY timestamp DESC LIMIT %d OFFSET %d",
|
|
where, limit, offset,
|
|
)
|
|
|
|
rows, err := c.conn.Query(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("clickhouse: query logs: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var entries []AuditEntry
|
|
for rows.Next() {
|
|
var e AuditEntry
|
|
var tokenIn, tokenOut, tokenTotal uint32
|
|
var latencyMs uint32
|
|
var piiCount uint16
|
|
if err := rows.Scan(
|
|
&e.RequestID, &e.TenantID, &e.UserID, &e.Timestamp,
|
|
&e.ModelRequested, &e.ModelUsed, &e.Provider,
|
|
&e.Department, &e.UserRole, &e.PromptHash, &e.ResponseHash,
|
|
&e.SensitivityLevel, &tokenIn, &tokenOut, &tokenTotal,
|
|
&e.CostUSD, &latencyMs, &e.Status, &e.ErrorType, &piiCount, &e.Stream,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("clickhouse: scan: %w", err)
|
|
}
|
|
e.TokenInput = int(tokenIn)
|
|
e.TokenOutput = int(tokenOut)
|
|
e.TokenTotal = int(tokenTotal)
|
|
e.LatencyMs = int(latencyMs)
|
|
e.PIIEntityCount = int(piiCount)
|
|
// prompt_anonymized is intentionally excluded from query results.
|
|
entries = append(entries, e)
|
|
}
|
|
|
|
return &AuditResult{Data: entries, Total: len(entries)}, nil
|
|
}
|
|
|
|
func (c *ClickHouseLogger) QueryCosts(ctx context.Context, q CostQuery) (*CostResult, error) {
|
|
groupField := "provider"
|
|
switch q.GroupBy {
|
|
case "model":
|
|
groupField = "model_used"
|
|
case "department":
|
|
groupField = "department"
|
|
}
|
|
|
|
var conditions []string
|
|
var args []interface{}
|
|
|
|
conditions = append(conditions, "tenant_id = ?")
|
|
args = append(args, q.TenantID)
|
|
|
|
if !q.StartTime.IsZero() {
|
|
conditions = append(conditions, "timestamp >= ?")
|
|
args = append(args, q.StartTime)
|
|
}
|
|
if !q.EndTime.IsZero() {
|
|
conditions = append(conditions, "timestamp <= ?")
|
|
args = append(args, q.EndTime)
|
|
}
|
|
|
|
where := strings.Join(conditions, " AND ")
|
|
query := fmt.Sprintf(
|
|
"SELECT %s, sum(token_total), sum(cost_usd), count() FROM audit_logs WHERE %s GROUP BY %s ORDER BY %s",
|
|
groupField, where, groupField, groupField,
|
|
)
|
|
|
|
rows, err := c.conn.Query(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("clickhouse: query costs: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var data []CostSummary
|
|
for rows.Next() {
|
|
var s CostSummary
|
|
var tokens uint64
|
|
var count uint64
|
|
if err := rows.Scan(&s.Key, &tokens, &s.TotalCostUSD, &count); err != nil {
|
|
return nil, fmt.Errorf("clickhouse: scan cost: %w", err)
|
|
}
|
|
s.TotalTokens = int(tokens)
|
|
s.RequestCount = int(count)
|
|
data = append(data, s)
|
|
}
|
|
sort.Slice(data, func(i, j int) bool { return data[i].Key < data[j].Key })
|
|
return &CostResult{Data: data}, nil
|
|
}
|