veylant/internal/ratelimit/pg_store.go
2026-02-23 13:35:04 +01:00

102 lines
3.2 KiB
Go

package ratelimit
import (
"context"
"database/sql"
"errors"
"fmt"
"go.uber.org/zap"
)
// ErrNotFound is returned when no config exists for the requested tenant.
// Get returns it when the lookup matches no row, and Delete returns it when
// no row was removed. Both return it unwrapped, so errors.Is(err, ErrNotFound)
// matches.
var ErrNotFound = errors.New("rate limit config not found")
// Store is the PostgreSQL-backed repository for per-tenant rate limit
// configurations. Construct it with NewStore.
type Store struct {
	// db is the database handle that every query issued by Store runs against.
	db *sql.DB

	// logger is the zap logger handed in at construction time.
	logger *zap.Logger
}
// NewStore builds a Store around the supplied database handle and logger and
// returns a pointer to it. Neither argument is validated here.
func NewStore(db *sql.DB, logger *zap.Logger) *Store {
	store := new(Store)
	store.db = db
	store.logger = logger
	return store
}
// List returns every row of rate_limit_configs, ordered by tenant_id.
//
// On success the slice holds one RateLimitConfig per row; it is nil when the
// table has no rows. On any failure (running the query, scanning a row, or an
// error surfaced while iterating) the returned slice is nil and the error is
// wrapped with a "ratelimit:" prefix, so callers never receive a partially
// populated result alongside a non-nil error.
func (s *Store) List(ctx context.Context) ([]RateLimitConfig, error) {
	rows, err := s.db.QueryContext(ctx,
		`SELECT tenant_id, requests_per_min, burst_size, user_rpm, user_burst, is_enabled
FROM rate_limit_configs ORDER BY tenant_id`)
	if err != nil {
		return nil, fmt.Errorf("ratelimit: list: %w", err)
	}
	defer rows.Close()

	var cfgs []RateLimitConfig
	for rows.Next() {
		var c RateLimitConfig
		if err := rows.Scan(&c.TenantID, &c.RequestsPerMin, &c.BurstSize,
			&c.UserRPM, &c.UserBurst, &c.IsEnabled); err != nil {
			return nil, fmt.Errorf("ratelimit: scan: %w", err)
		}
		cfgs = append(cfgs, c)
	}

	// rows.Next also returns false when iteration stops because of an error
	// (for example a cancelled context), so the rows collected so far may be
	// incomplete. Discard them rather than returning a truncated list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("ratelimit: list: %w", err)
	}
	return cfgs, nil
}
// Get looks up the rate limit configuration stored for tenantID.
//
// It returns ErrNotFound (unwrapped) when the query matches no row, and a
// wrapped "ratelimit: get" error for any other failure. In both error cases
// the returned RateLimitConfig is the zero value.
func (s *Store) Get(ctx context.Context, tenantID string) (RateLimitConfig, error) {
	const query = `SELECT tenant_id, requests_per_min, burst_size, user_rpm, user_burst, is_enabled
FROM rate_limit_configs WHERE tenant_id = $1`

	var cfg RateLimitConfig
	row := s.db.QueryRowContext(ctx, query, tenantID)
	err := row.Scan(
		&cfg.TenantID,
		&cfg.RequestsPerMin,
		&cfg.BurstSize,
		&cfg.UserRPM,
		&cfg.UserBurst,
		&cfg.IsEnabled,
	)

	switch {
	case err == nil:
		return cfg, nil
	case errors.Is(err, sql.ErrNoRows):
		// No row for this tenant: report the package-level sentinel.
		return RateLimitConfig{}, ErrNotFound
	default:
		return RateLimitConfig{}, fmt.Errorf("ratelimit: get: %w", err)
	}
}
// Upsert stores cfg as the rate limit configuration for cfg.TenantID.
//
// A single INSERT ... ON CONFLICT (tenant_id) DO UPDATE statement is used: a
// new row is inserted, or the existing row's limits and is_enabled flag are
// overwritten. updated_at is set to NOW() on both paths. The values returned
// are read back from the statement's RETURNING clause. On failure the zero
// RateLimitConfig and a wrapped "ratelimit: upsert" error are returned.
func (s *Store) Upsert(ctx context.Context, cfg RateLimitConfig) (RateLimitConfig, error) {
	const upsertSQL = `
INSERT INTO rate_limit_configs
(tenant_id, requests_per_min, burst_size, user_rpm, user_burst, is_enabled, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, NOW())
ON CONFLICT (tenant_id) DO UPDATE SET
requests_per_min = EXCLUDED.requests_per_min,
burst_size = EXCLUDED.burst_size,
user_rpm = EXCLUDED.user_rpm,
user_burst = EXCLUDED.user_burst,
is_enabled = EXCLUDED.is_enabled,
updated_at = NOW()
RETURNING tenant_id, requests_per_min, burst_size, user_rpm, user_burst, is_enabled`

	row := s.db.QueryRowContext(ctx, upsertSQL,
		cfg.TenantID,
		cfg.RequestsPerMin,
		cfg.BurstSize,
		cfg.UserRPM,
		cfg.UserBurst,
		cfg.IsEnabled,
	)

	var saved RateLimitConfig
	if err := row.Scan(
		&saved.TenantID,
		&saved.RequestsPerMin,
		&saved.BurstSize,
		&saved.UserRPM,
		&saved.UserBurst,
		&saved.IsEnabled,
	); err != nil {
		return RateLimitConfig{}, fmt.Errorf("ratelimit: upsert: %w", err)
	}
	return saved, nil
}
// Delete removes the rate_limit_configs row for tenantID.
//
// It returns ErrNotFound (unwrapped) when the statement affected no rows, and
// a wrapped "ratelimit: delete" error when the statement fails or when the
// affected-row count cannot be obtained from the driver.
func (s *Store) Delete(ctx context.Context, tenantID string) error {
	res, err := s.db.ExecContext(ctx,
		`DELETE FROM rate_limit_configs WHERE tenant_id = $1`, tenantID)
	if err != nil {
		return fmt.Errorf("ratelimit: delete: %w", err)
	}

	// RowsAffected can fail; if its error were ignored the count would be 0
	// and the failure would be misreported as ErrNotFound.
	n, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("ratelimit: delete: rows affected: %w", err)
	}
	if n == 0 {
		return ErrNotFound
	}
	return nil
}