// Package openai implements the provider.Adapter interface for the OpenAI API. package openai import ( "bufio" "bytes" "context" "encoding/json" "fmt" "net" "net/http" "time" "github.com/veylant/ia-gateway/internal/apierror" "github.com/veylant/ia-gateway/internal/provider" ) // Config holds the OpenAI adapter configuration. type Config struct { APIKey string BaseURL string TimeoutSeconds int MaxConns int } // Adapter implements provider.Adapter for the OpenAI API. type Adapter struct { cfg Config regularClient *http.Client // used for non-streaming requests (has full Timeout) streamingClient *http.Client // used for streaming (no full Timeout; transport-level only) } // New creates a new OpenAI Adapter with two separate HTTP client pools: // - regularClient: includes a full request timeout (safe for non-streaming) // - streamingClient: omits the per-request Timeout so the connection stays // open for the entire SSE stream duration; dial and header timeouts still apply func New(cfg Config) *Adapter { if cfg.BaseURL == "" { cfg.BaseURL = "https://api.openai.com/v1" } dialer := &net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, } transport := &http.Transport{ DialContext: dialer.DialContext, TLSHandshakeTimeout: 10 * time.Second, ResponseHeaderTimeout: 30 * time.Second, ExpectContinueTimeout: 1 * time.Second, MaxIdleConns: cfg.MaxConns, MaxIdleConnsPerHost: cfg.MaxConns, IdleConnTimeout: 90 * time.Second, DisableCompression: false, } timeout := time.Duration(cfg.TimeoutSeconds) * time.Second if cfg.TimeoutSeconds == 0 { timeout = 30 * time.Second } return &Adapter{ cfg: cfg, regularClient: &http.Client{ Transport: transport, Timeout: timeout, }, // Streaming client shares the same Transport (connection pool) but // has no per-request Timeout. streamingClient: &http.Client{ Transport: transport, }, } } // Validate checks that req is well-formed for the OpenAI adapter. func (a *Adapter) Validate(req *provider.ChatRequest) error { if req.Model == "" { return apierror.NewBadRequestError("field 'model' is required") } if len(req.Messages) == 0 { return apierror.NewBadRequestError("field 'messages' must not be empty") } return nil } // Send performs a non-streaming POST /chat/completions request to OpenAI. func (a *Adapter) Send(ctx context.Context, req *provider.ChatRequest) (*provider.ChatResponse, error) { body, err := json.Marshal(req) if err != nil { return nil, apierror.NewInternalError(fmt.Sprintf("marshaling request: %v", err)) } httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, a.cfg.BaseURL+"/chat/completions", bytes.NewReader(body), ) if err != nil { return nil, apierror.NewInternalError(fmt.Sprintf("building request: %v", err)) } a.setHeaders(httpReq) resp, err := a.regularClient.Do(httpReq) if err != nil { if ctx.Err() != nil { return nil, apierror.NewTimeoutError("upstream request timed out") } return nil, apierror.NewUpstreamError(fmt.Sprintf("calling OpenAI: %v", err)) } defer resp.Body.Close() //nolint:errcheck if resp.StatusCode != http.StatusOK { return nil, mapUpstreamStatus(resp.StatusCode) } var chatResp provider.ChatResponse if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { return nil, apierror.NewUpstreamError(fmt.Sprintf("decoding response: %v", err)) } return &chatResp, nil } // Stream performs a streaming POST /chat/completions request and pipes the SSE // chunks directly to w. Callers must set the Content-Type and other SSE headers // before invoking Stream (the proxy handler does this). func (a *Adapter) Stream(ctx context.Context, req *provider.ChatRequest, w http.ResponseWriter) error { // Force stream flag in the upstream request. req.Stream = true body, err := json.Marshal(req) if err != nil { return apierror.NewInternalError(fmt.Sprintf("marshaling stream request: %v", err)) } httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, a.cfg.BaseURL+"/chat/completions", bytes.NewReader(body), ) if err != nil { return apierror.NewInternalError(fmt.Sprintf("building stream request: %v", err)) } a.setHeaders(httpReq) resp, err := a.streamingClient.Do(httpReq) if err != nil { if ctx.Err() != nil { return apierror.NewTimeoutError("upstream stream timed out") } return apierror.NewUpstreamError(fmt.Sprintf("opening stream to OpenAI: %v", err)) } defer resp.Body.Close() //nolint:errcheck if resp.StatusCode != http.StatusOK { return mapUpstreamStatus(resp.StatusCode) } return PipeSSE(bufio.NewScanner(resp.Body), w) } // HealthCheck verifies that the OpenAI API is reachable by calling GET /models. func (a *Adapter) HealthCheck(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, a.cfg.BaseURL+"/models", nil) if err != nil { return fmt.Errorf("building health check request: %w", err) } a.setHeaders(req) resp, err := a.regularClient.Do(req) if err != nil { return fmt.Errorf("OpenAI health check failed: %w", err) } defer resp.Body.Close() //nolint:errcheck if resp.StatusCode != http.StatusOK { return fmt.Errorf("OpenAI health check returned HTTP %d", resp.StatusCode) } return nil } // setHeaders attaches the required OpenAI authentication and content headers. func (a *Adapter) setHeaders(req *http.Request) { req.Header.Set("Authorization", "Bearer "+a.cfg.APIKey) req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") } // mapUpstreamStatus converts an OpenAI HTTP status code to a typed APIError. func mapUpstreamStatus(status int) *apierror.APIError { switch status { case http.StatusUnauthorized: return apierror.NewAuthError("upstream rejected the API key") case http.StatusTooManyRequests: return apierror.NewRateLimitError("OpenAI rate limit reached") case http.StatusBadRequest: return apierror.NewBadRequestError("OpenAI rejected the request") case http.StatusGatewayTimeout, http.StatusRequestTimeout: return apierror.NewTimeoutError("OpenAI request timed out") default: return apierror.NewUpstreamError(fmt.Sprintf("OpenAI returned HTTP %d", status)) } }