102 lines
3.2 KiB
Go
102 lines
3.2 KiB
Go
package ratelimit
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// ErrNotFound reports that no rate limit configuration row exists for the
// requested tenant. Callers should detect it with errors.Is rather than by
// comparing error strings.
var (
	ErrNotFound = errors.New("rate limit config not found")
)
|
|
|
|
// Store persists per-tenant rate limit configurations in PostgreSQL.
|
|
type Store struct {
|
|
db *sql.DB
|
|
logger *zap.Logger
|
|
}
|
|
|
|
// NewStore creates a new PostgreSQL-backed rate limit store.
|
|
func NewStore(db *sql.DB, logger *zap.Logger) *Store {
|
|
return &Store{db: db, logger: logger}
|
|
}
|
|
|
|
// List returns all per-tenant rate limit configurations.
|
|
func (s *Store) List(ctx context.Context) ([]RateLimitConfig, error) {
|
|
rows, err := s.db.QueryContext(ctx,
|
|
`SELECT tenant_id, requests_per_min, burst_size, user_rpm, user_burst, is_enabled
|
|
FROM rate_limit_configs ORDER BY tenant_id`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("ratelimit: list: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var cfgs []RateLimitConfig
|
|
for rows.Next() {
|
|
var c RateLimitConfig
|
|
if err := rows.Scan(&c.TenantID, &c.RequestsPerMin, &c.BurstSize,
|
|
&c.UserRPM, &c.UserBurst, &c.IsEnabled); err != nil {
|
|
return nil, fmt.Errorf("ratelimit: scan: %w", err)
|
|
}
|
|
cfgs = append(cfgs, c)
|
|
}
|
|
return cfgs, rows.Err()
|
|
}
|
|
|
|
// Get returns the config for a specific tenant or ErrNotFound.
|
|
func (s *Store) Get(ctx context.Context, tenantID string) (RateLimitConfig, error) {
|
|
var c RateLimitConfig
|
|
err := s.db.QueryRowContext(ctx,
|
|
`SELECT tenant_id, requests_per_min, burst_size, user_rpm, user_burst, is_enabled
|
|
FROM rate_limit_configs WHERE tenant_id = $1`, tenantID).
|
|
Scan(&c.TenantID, &c.RequestsPerMin, &c.BurstSize, &c.UserRPM, &c.UserBurst, &c.IsEnabled)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return RateLimitConfig{}, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return RateLimitConfig{}, fmt.Errorf("ratelimit: get: %w", err)
|
|
}
|
|
return c, nil
|
|
}
|
|
|
|
// Upsert creates or updates the rate limit config for a tenant.
|
|
func (s *Store) Upsert(ctx context.Context, cfg RateLimitConfig) (RateLimitConfig, error) {
|
|
var out RateLimitConfig
|
|
err := s.db.QueryRowContext(ctx, `
|
|
INSERT INTO rate_limit_configs
|
|
(tenant_id, requests_per_min, burst_size, user_rpm, user_burst, is_enabled, updated_at)
|
|
VALUES ($1, $2, $3, $4, $5, $6, NOW())
|
|
ON CONFLICT (tenant_id) DO UPDATE SET
|
|
requests_per_min = EXCLUDED.requests_per_min,
|
|
burst_size = EXCLUDED.burst_size,
|
|
user_rpm = EXCLUDED.user_rpm,
|
|
user_burst = EXCLUDED.user_burst,
|
|
is_enabled = EXCLUDED.is_enabled,
|
|
updated_at = NOW()
|
|
RETURNING tenant_id, requests_per_min, burst_size, user_rpm, user_burst, is_enabled`,
|
|
cfg.TenantID, cfg.RequestsPerMin, cfg.BurstSize,
|
|
cfg.UserRPM, cfg.UserBurst, cfg.IsEnabled,
|
|
).Scan(&out.TenantID, &out.RequestsPerMin, &out.BurstSize,
|
|
&out.UserRPM, &out.UserBurst, &out.IsEnabled)
|
|
if err != nil {
|
|
return RateLimitConfig{}, fmt.Errorf("ratelimit: upsert: %w", err)
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// Delete removes the per-tenant override. Returns ErrNotFound if absent.
|
|
func (s *Store) Delete(ctx context.Context, tenantID string) error {
|
|
res, err := s.db.ExecContext(ctx,
|
|
`DELETE FROM rate_limit_configs WHERE tenant_id = $1`, tenantID)
|
|
if err != nil {
|
|
return fmt.Errorf("ratelimit: delete: %w", err)
|
|
}
|
|
n, _ := res.RowsAffected()
|
|
if n == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|