veylant/internal/provider/openai/adapter.go
2026-02-23 13:35:04 +01:00

206 lines
6.2 KiB
Go

// Package openai implements the provider.Adapter interface for the OpenAI API.
package openai
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"net"
"net/http"
"time"
"github.com/veylant/ia-gateway/internal/apierror"
"github.com/veylant/ia-gateway/internal/provider"
)
// Config holds the OpenAI adapter configuration.
type Config struct {
	// APIKey is sent as a Bearer token in the Authorization header.
	APIKey string
	// BaseURL is the API root that endpoint paths such as "/chat/completions"
	// are appended to. New defaults it to "https://api.openai.com/v1" when empty.
	BaseURL string
	// TimeoutSeconds bounds a whole non-streaming request and the wait for
	// response headers on any request. Values <= 0 fall back to 30 seconds.
	TimeoutSeconds int
	// MaxConns is used for both MaxIdleConns and MaxIdleConnsPerHost on the
	// shared transport. NOTE(review): per net/http, 0 means no overall idle
	// limit but only http.DefaultMaxIdleConnsPerHost idle conns per host.
	MaxConns int
}

// Adapter implements provider.Adapter for the OpenAI API.
type Adapter struct {
	cfg             Config
	regularClient   *http.Client // used for non-streaming requests (has full Timeout)
	streamingClient *http.Client // used for streaming (no full Timeout; transport-level only)
}

// New creates a new OpenAI Adapter with two HTTP clients that share one
// Transport (and therefore one connection pool):
//   - regularClient: Client.Timeout is set to the configured timeout, which
//     covers the whole exchange including reading the body (safe for
//     non-streaming requests);
//   - streamingClient: no Client.Timeout, so the connection can stay open for
//     the entire SSE stream; the Transport's dial, TLS-handshake and
//     response-header timeouts still apply.
func New(cfg Config) *Adapter {
	if cfg.BaseURL == "" {
		cfg.BaseURL = "https://api.openai.com/v1"
	}

	// A non-positive Client.Timeout disables the timeout entirely in net/http,
	// so anything <= 0 falls back to the 30s default.
	timeout := 30 * time.Second
	if cfg.TimeoutSeconds > 0 {
		timeout = time.Duration(cfg.TimeoutSeconds) * time.Second
	}

	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	transport := &http.Transport{
		DialContext:         dialer.DialContext,
		TLSHandshakeTimeout: 10 * time.Second,
		// A non-streaming completion may not send response headers until
		// generation has finished. A fixed header timeout shorter than the
		// client timeout would fail such requests early regardless of
		// TimeoutSeconds, so the two are tied together.
		ResponseHeaderTimeout: timeout,
		ExpectContinueTimeout: 1 * time.Second,
		MaxIdleConns:          cfg.MaxConns,
		MaxIdleConnsPerHost:   cfg.MaxConns,
		IdleConnTimeout:       90 * time.Second,
		DisableCompression:    false,
	}

	return &Adapter{
		cfg: cfg,
		regularClient: &http.Client{
			Transport: transport,
			Timeout:   timeout,
		},
		// Streaming client shares the same Transport (connection pool) but
		// has no per-request Timeout.
		streamingClient: &http.Client{
			Transport: transport,
		},
	}
}
// Validate reports whether req carries the minimum fields the OpenAI adapter
// needs: a non-empty model name and at least one message. It returns a
// bad-request APIError describing the first missing field, or nil.
func (a *Adapter) Validate(req *provider.ChatRequest) error {
	switch {
	case req.Model == "":
		return apierror.NewBadRequestError("field 'model' is required")
	case len(req.Messages) == 0:
		return apierror.NewBadRequestError("field 'messages' must not be empty")
	default:
		return nil
	}
}
// Send performs a non-streaming POST /chat/completions request to OpenAI and
// decodes the JSON response into a provider.ChatResponse.
//
// Errors are typed apierror values:
//   - InternalError when the request cannot be marshaled or built;
//   - TimeoutError when ctx is done or a client/transport timeout fires
//     (regularClient.Timeout, dial, TLS handshake, response headers, body read);
//   - the mapUpstreamStatus mapping for any non-200 status;
//   - UpstreamError for other transport failures or an undecodable body.
func (a *Adapter) Send(ctx context.Context, req *provider.ChatRequest) (*provider.ChatResponse, error) {
	body, err := json.Marshal(req)
	if err != nil {
		return nil, apierror.NewInternalError(fmt.Sprintf("marshaling request: %v", err))
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
		a.cfg.BaseURL+"/chat/completions",
		bytes.NewReader(body),
	)
	if err != nil {
		return nil, apierror.NewInternalError(fmt.Sprintf("building request: %v", err))
	}
	a.setHeaders(httpReq)

	// timedOut reports whether err comes from ctx or from a client/transport
	// timeout. When regularClient.Timeout expires, ctx is not cancelled, so
	// checking ctx.Err() alone would misreport it as a generic upstream
	// failure. Client.Do documents that it returns *url.Error, which
	// implements net.Error.
	timedOut := func(err error) bool {
		if ctx.Err() != nil {
			return true
		}
		netErr, ok := err.(net.Error)
		return ok && netErr.Timeout()
	}

	resp, err := a.regularClient.Do(httpReq)
	if err != nil {
		if timedOut(err) {
			return nil, apierror.NewTimeoutError("upstream request timed out")
		}
		return nil, apierror.NewUpstreamError(fmt.Sprintf("calling OpenAI: %v", err))
	}
	defer resp.Body.Close() //nolint:errcheck

	if resp.StatusCode != http.StatusOK {
		return nil, mapUpstreamStatus(resp.StatusCode)
	}

	var chatResp provider.ChatResponse
	if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
		// Client.Timeout also covers reading the body, so a slow body can
		// time out here rather than inside Do.
		if timedOut(err) {
			return nil, apierror.NewTimeoutError("upstream request timed out")
		}
		return nil, apierror.NewUpstreamError(fmt.Sprintf("decoding response: %v", err))
	}
	return &chatResp, nil
}
// Stream performs a streaming POST /chat/completions request and pipes the SSE
// chunks directly to w via PipeSSE. Callers must set the Content-Type and other
// SSE headers before invoking Stream (the proxy handler does this).
//
// Stream sets req.Stream = true on the caller's request before marshaling it.
//
// Errors are typed apierror values:
//   - InternalError for marshal/build failures;
//   - TimeoutError when ctx is done or a transport-level timeout fires before
//     the response arrives;
//   - the mapUpstreamStatus mapping for non-200 statuses;
//   - UpstreamError for other connection failures.
//
// Otherwise Stream returns whatever PipeSSE returns.
func (a *Adapter) Stream(ctx context.Context, req *provider.ChatRequest, w http.ResponseWriter) error {
	// Force stream flag in the upstream request.
	req.Stream = true
	body, err := json.Marshal(req)
	if err != nil {
		return apierror.NewInternalError(fmt.Sprintf("marshaling stream request: %v", err))
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost,
		a.cfg.BaseURL+"/chat/completions",
		bytes.NewReader(body),
	)
	if err != nil {
		return apierror.NewInternalError(fmt.Sprintf("building stream request: %v", err))
	}
	a.setHeaders(httpReq)

	resp, err := a.streamingClient.Do(httpReq)
	if err != nil {
		// streamingClient has no Client.Timeout, but the Transport's dial,
		// TLS-handshake and response-header timeouts can still fire without
		// ctx being done. The *url.Error returned by Do reports those via
		// net.Error.Timeout().
		netErr, isNetErr := err.(net.Error)
		if ctx.Err() != nil || (isNetErr && netErr.Timeout()) {
			return apierror.NewTimeoutError("upstream stream timed out")
		}
		return apierror.NewUpstreamError(fmt.Sprintf("opening stream to OpenAI: %v", err))
	}
	defer resp.Body.Close() //nolint:errcheck

	if resp.StatusCode != http.StatusOK {
		return mapUpstreamStatus(resp.StatusCode)
	}

	// bufio.Scanner rejects any line longer than its buffer (64 KiB by
	// default) with bufio.ErrTooLong. Allow SSE lines up to 1 MiB before
	// giving up.
	scanner := bufio.NewScanner(resp.Body)
	scanner.Buffer(make([]byte, 0, 64*1024), 1<<20)
	return PipeSSE(scanner, w)
}
// HealthCheck probes the OpenAI API with an authenticated GET /models request,
// bounded by a 5-second timeout derived from ctx. It returns nil only when the
// upstream answers 200 OK. Otherwise it returns a wrapped error describing
// whether building the request, sending it, or the status code failed.
func (a *Adapter) HealthCheck(ctx context.Context) error {
	probeCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	endpoint := a.cfg.BaseURL + "/models"
	probe, buildErr := http.NewRequestWithContext(probeCtx, http.MethodGet, endpoint, nil)
	if buildErr != nil {
		return fmt.Errorf("building health check request: %w", buildErr)
	}
	a.setHeaders(probe)

	resp, doErr := a.regularClient.Do(probe)
	if doErr != nil {
		return fmt.Errorf("OpenAI health check failed: %w", doErr)
	}
	defer resp.Body.Close() //nolint:errcheck

	if code := resp.StatusCode; code != http.StatusOK {
		return fmt.Errorf("OpenAI health check returned HTTP %d", code)
	}
	return nil
}
// setHeaders stamps req with the Bearer credential from the adapter config and
// declares JSON as both the request body type and the accepted response type.
// Any existing values for these three headers are replaced.
func (a *Adapter) setHeaders(req *http.Request) {
	hdr := req.Header
	hdr.Set("Authorization", "Bearer "+a.cfg.APIKey)
	hdr.Set("Content-Type", "application/json")
	hdr.Set("Accept", "application/json")
}
// mapUpstreamStatus translates a non-200 status returned by OpenAI into the
// gateway's typed APIError:
//   - 400 becomes a bad-request error;
//   - 401 becomes an auth error;
//   - 408 and 504 become timeouts;
//   - 429 becomes a rate-limit error.
//
// Every other status becomes an upstream error that carries the numeric code.
func mapUpstreamStatus(status int) *apierror.APIError {
	switch {
	case status == http.StatusBadRequest:
		return apierror.NewBadRequestError("OpenAI rejected the request")
	case status == http.StatusUnauthorized:
		return apierror.NewAuthError("upstream rejected the API key")
	case status == http.StatusRequestTimeout || status == http.StatusGatewayTimeout:
		return apierror.NewTimeoutError("OpenAI request timed out")
	case status == http.StatusTooManyRequests:
		return apierror.NewRateLimitError("OpenAI rate limit reached")
	}
	return apierror.NewUpstreamError(fmt.Sprintf("OpenAI returned HTTP %d", status))
}