// Package azure implements provider.Adapter for the Azure OpenAI Service.
// Azure uses the same request/response format as OpenAI but differs in:
//   - URL structure: https://{resource}.openai.azure.com/openai/deployments/{deployment}/...
//   - Authentication header: "api-key" instead of "Authorization: Bearer"
package azure

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"time"

	"github.com/veylant/ia-gateway/internal/apierror"
	"github.com/veylant/ia-gateway/internal/provider"
	"github.com/veylant/ia-gateway/internal/provider/openai"
)

// Config holds Azure OpenAI adapter configuration.
type Config struct {
	// APIKey is sent verbatim in the "api-key" request header.
	APIKey       string
	ResourceName string // Azure resource name, e.g. "my-azure-openai"
	DeploymentID string // Deployment name, e.g. "gpt-4o"
	// APIVersion is appended to every URL as the api-version query parameter;
	// the constructor substitutes "2024-02-01" when it is empty.
	APIVersion string // e.g. "2024-02-01"
	// TimeoutSeconds is the overall timeout of the non-streaming client;
	// zero selects a 30-second default.
	TimeoutSeconds int
	// MaxConns is used as both MaxIdleConns and MaxIdleConnsPerHost of the
	// shared HTTP transport.
	MaxConns int
}

// Adapter implements provider.Adapter for Azure OpenAI Service.
type Adapter struct {
	cfg         Config
	testBaseURL string // non-empty in tests: overrides Azure URL construction
	// regularClient carries an overall request timeout; Send and HealthCheck use it.
	regularClient *http.Client
	// streamingClient has no overall timeout; Stream uses it.
	streamingClient *http.Client
}

// New creates an Azure OpenAI adapter.
func New(cfg Config) *Adapter {
	return NewWithBaseURL("", cfg)
}

// NewWithBaseURL creates an adapter with a custom base URL — used in tests to
// point the adapter at a mock server instead of the real Azure endpoint.
func NewWithBaseURL(baseURL string, cfg Config) *Adapter { if cfg.APIVersion == "" { cfg.APIVersion = "2024-02-01" } dialer := &net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, } transport := &http.Transport{ DialContext: dialer.DialContext, TLSHandshakeTimeout: 10 * time.Second, ResponseHeaderTimeout: 30 * time.Second, ExpectContinueTimeout: 1 * time.Second, MaxIdleConns: cfg.MaxConns, MaxIdleConnsPerHost: cfg.MaxConns, IdleConnTimeout: 90 * time.Second, } timeout := time.Duration(cfg.TimeoutSeconds) * time.Second if cfg.TimeoutSeconds == 0 { timeout = 30 * time.Second } return &Adapter{ cfg: cfg, testBaseURL: baseURL, regularClient: &http.Client{ Transport: transport, Timeout: timeout, }, streamingClient: &http.Client{ Transport: transport, }, } } // Validate checks that req is well-formed for the Azure adapter. func (a *Adapter) Validate(req *provider.ChatRequest) error { if req.Model == "" { return apierror.NewBadRequestError("field 'model' is required") } if len(req.Messages) == 0 { return apierror.NewBadRequestError("field 'messages' must not be empty") } return nil } // Send performs a non-streaming chat completion request to Azure OpenAI. 
func (a *Adapter) Send(ctx context.Context, req *provider.ChatRequest) (*provider.ChatResponse, error) { body, err := json.Marshal(req) if err != nil { return nil, apierror.NewInternalError(fmt.Sprintf("marshaling request: %v", err)) } httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, a.chatURL(), bytes.NewReader(body)) if err != nil { return nil, apierror.NewInternalError(fmt.Sprintf("building request: %v", err)) } a.setHeaders(httpReq) resp, err := a.regularClient.Do(httpReq) if err != nil { if ctx.Err() != nil { return nil, apierror.NewTimeoutError("Azure OpenAI request timed out") } return nil, apierror.NewUpstreamError(fmt.Sprintf("calling Azure OpenAI: %v", err)) } defer resp.Body.Close() //nolint:errcheck if resp.StatusCode != http.StatusOK { return nil, mapUpstreamStatus(resp.StatusCode) } var chatResp provider.ChatResponse if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil { return nil, apierror.NewUpstreamError(fmt.Sprintf("decoding response: %v", err)) } return &chatResp, nil } // Stream performs a streaming chat completion request and pipes SSE chunks to w. 
func (a *Adapter) Stream(ctx context.Context, req *provider.ChatRequest, w http.ResponseWriter) error { req.Stream = true body, err := json.Marshal(req) if err != nil { return apierror.NewInternalError(fmt.Sprintf("marshaling stream request: %v", err)) } httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, a.chatURL(), bytes.NewReader(body)) if err != nil { return apierror.NewInternalError(fmt.Sprintf("building stream request: %v", err)) } a.setHeaders(httpReq) resp, err := a.streamingClient.Do(httpReq) if err != nil { if ctx.Err() != nil { return apierror.NewTimeoutError("Azure OpenAI stream timed out") } return apierror.NewUpstreamError(fmt.Sprintf("opening stream to Azure OpenAI: %v", err)) } defer resp.Body.Close() //nolint:errcheck if resp.StatusCode != http.StatusOK { return mapUpstreamStatus(resp.StatusCode) } // Azure SSE format is identical to OpenAI — reuse PipeSSE. return openai.PipeSSE(bufio.NewScanner(resp.Body), w) } // HealthCheck verifies Azure OpenAI is reachable by listing deployments. func (a *Adapter) HealthCheck(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() var url string if a.testBaseURL != "" { url = a.testBaseURL + fmt.Sprintf("/openai/deployments?api-version=%s", a.cfg.APIVersion) } else { url = fmt.Sprintf( "https://%s.openai.azure.com/openai/deployments?api-version=%s", a.cfg.ResourceName, a.cfg.APIVersion, ) } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return fmt.Errorf("building Azure health check request: %w", err) } a.setHeaders(req) resp, err := a.regularClient.Do(req) if err != nil { return fmt.Errorf("Azure OpenAI health check failed: %w", err) } defer resp.Body.Close() //nolint:errcheck if resp.StatusCode != http.StatusOK { return fmt.Errorf("Azure OpenAI health check returned HTTP %d", resp.StatusCode) } return nil } // chatURL builds the Azure OpenAI completions endpoint URL. 
// In tests, testBaseURL overrides Azure URL construction for mock server injection. func (a *Adapter) chatURL() string { if a.testBaseURL != "" { return a.testBaseURL + fmt.Sprintf( "/openai/deployments/%s/chat/completions?api-version=%s", a.cfg.DeploymentID, a.cfg.APIVersion, ) } return fmt.Sprintf( "https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s", a.cfg.ResourceName, a.cfg.DeploymentID, a.cfg.APIVersion, ) } // setHeaders attaches Azure authentication and content headers. // Azure uses "api-key" instead of "Authorization: Bearer". func (a *Adapter) setHeaders(req *http.Request) { req.Header.Set("api-key", a.cfg.APIKey) req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") } func mapUpstreamStatus(status int) *apierror.APIError { switch status { case http.StatusUnauthorized: return apierror.NewAuthError("Azure OpenAI rejected the API key") case http.StatusTooManyRequests: return apierror.NewRateLimitError("Azure OpenAI rate limit reached") case http.StatusBadRequest: return apierror.NewBadRequestError("Azure OpenAI rejected the request") case http.StatusGatewayTimeout, http.StatusRequestTimeout: return apierror.NewTimeoutError("Azure OpenAI request timed out") default: return apierror.NewUpstreamError(fmt.Sprintf("Azure OpenAI returned HTTP %d", status)) } }