230 lines
7.1 KiB
Go
230 lines
7.1 KiB
Go
// Package azure implements provider.Adapter for the Azure OpenAI Service.
|
|
// Azure uses the same request/response format as OpenAI but differs in:
|
|
// - URL structure: https://{resource}.openai.azure.com/openai/deployments/{deployment}/...
|
|
// - Authentication header: "api-key" instead of "Authorization: Bearer"
|
|
package azure
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/veylant/ia-gateway/internal/apierror"
|
|
"github.com/veylant/ia-gateway/internal/provider"
|
|
"github.com/veylant/ia-gateway/internal/provider/openai"
|
|
)
|
|
|
|
// Config holds Azure OpenAI adapter configuration.
type Config struct {
	// APIKey is the Azure OpenAI API key, sent in the "api-key" header.
	APIKey string
	// ResourceName is the Azure resource name, e.g. "my-azure-openai".
	// It forms the host: https://{ResourceName}.openai.azure.com.
	ResourceName string
	// DeploymentID is the deployment name, e.g. "gpt-4o". It selects the
	// model via the /openai/deployments/{DeploymentID}/... URL path.
	DeploymentID string
	// APIVersion is the Azure API version query parameter, e.g. "2024-02-01".
	// When empty, NewWithBaseURL defaults it to "2024-02-01".
	APIVersion string
	// TimeoutSeconds bounds non-streaming requests. When zero,
	// NewWithBaseURL applies a 30-second default. Streaming requests are
	// not bound by this timeout (they are cancelled via context instead).
	TimeoutSeconds int
	// MaxConns sets the transport's MaxIdleConns and MaxIdleConnsPerHost.
	MaxConns int
}
|
|
|
|
// Adapter implements provider.Adapter for Azure OpenAI Service.
//
// It keeps two HTTP clients over a shared transport: regularClient has an
// overall request timeout (for Send and HealthCheck), while streamingClient
// has none, since SSE responses may legitimately outlive any fixed timeout
// and are cancelled through the request context instead.
type Adapter struct {
	cfg Config

	// testBaseURL, when non-empty (tests only), overrides the
	// https://{resource}.openai.azure.com prefix in URL construction.
	testBaseURL string

	// regularClient is used for bounded request/response calls.
	regularClient *http.Client
	// streamingClient is used for SSE streams; no Client.Timeout set.
	streamingClient *http.Client
}
|
|
|
|
// New creates an Azure OpenAI adapter.
|
|
func New(cfg Config) *Adapter {
|
|
return NewWithBaseURL("", cfg)
|
|
}
|
|
|
|
// NewWithBaseURL creates an adapter with a custom base URL — used in tests to
|
|
// point the adapter at a mock server instead of the real Azure endpoint.
|
|
func NewWithBaseURL(baseURL string, cfg Config) *Adapter {
|
|
if cfg.APIVersion == "" {
|
|
cfg.APIVersion = "2024-02-01"
|
|
}
|
|
|
|
dialer := &net.Dialer{
|
|
Timeout: 30 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
}
|
|
transport := &http.Transport{
|
|
DialContext: dialer.DialContext,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
ResponseHeaderTimeout: 30 * time.Second,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
MaxIdleConns: cfg.MaxConns,
|
|
MaxIdleConnsPerHost: cfg.MaxConns,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
}
|
|
|
|
timeout := time.Duration(cfg.TimeoutSeconds) * time.Second
|
|
if cfg.TimeoutSeconds == 0 {
|
|
timeout = 30 * time.Second
|
|
}
|
|
|
|
return &Adapter{
|
|
cfg: cfg,
|
|
testBaseURL: baseURL,
|
|
regularClient: &http.Client{
|
|
Transport: transport,
|
|
Timeout: timeout,
|
|
},
|
|
streamingClient: &http.Client{
|
|
Transport: transport,
|
|
},
|
|
}
|
|
}
|
|
|
|
// Validate checks that req is well-formed for the Azure adapter.
|
|
func (a *Adapter) Validate(req *provider.ChatRequest) error {
|
|
if req.Model == "" {
|
|
return apierror.NewBadRequestError("field 'model' is required")
|
|
}
|
|
if len(req.Messages) == 0 {
|
|
return apierror.NewBadRequestError("field 'messages' must not be empty")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Send performs a non-streaming chat completion request to Azure OpenAI.
|
|
func (a *Adapter) Send(ctx context.Context, req *provider.ChatRequest) (*provider.ChatResponse, error) {
|
|
body, err := json.Marshal(req)
|
|
if err != nil {
|
|
return nil, apierror.NewInternalError(fmt.Sprintf("marshaling request: %v", err))
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, a.chatURL(), bytes.NewReader(body))
|
|
if err != nil {
|
|
return nil, apierror.NewInternalError(fmt.Sprintf("building request: %v", err))
|
|
}
|
|
a.setHeaders(httpReq)
|
|
|
|
resp, err := a.regularClient.Do(httpReq)
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return nil, apierror.NewTimeoutError("Azure OpenAI request timed out")
|
|
}
|
|
return nil, apierror.NewUpstreamError(fmt.Sprintf("calling Azure OpenAI: %v", err))
|
|
}
|
|
defer resp.Body.Close() //nolint:errcheck
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return nil, mapUpstreamStatus(resp.StatusCode)
|
|
}
|
|
|
|
var chatResp provider.ChatResponse
|
|
if err := json.NewDecoder(resp.Body).Decode(&chatResp); err != nil {
|
|
return nil, apierror.NewUpstreamError(fmt.Sprintf("decoding response: %v", err))
|
|
}
|
|
return &chatResp, nil
|
|
}
|
|
|
|
// Stream performs a streaming chat completion request and pipes SSE chunks to w.
|
|
func (a *Adapter) Stream(ctx context.Context, req *provider.ChatRequest, w http.ResponseWriter) error {
|
|
req.Stream = true
|
|
|
|
body, err := json.Marshal(req)
|
|
if err != nil {
|
|
return apierror.NewInternalError(fmt.Sprintf("marshaling stream request: %v", err))
|
|
}
|
|
|
|
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, a.chatURL(), bytes.NewReader(body))
|
|
if err != nil {
|
|
return apierror.NewInternalError(fmt.Sprintf("building stream request: %v", err))
|
|
}
|
|
a.setHeaders(httpReq)
|
|
|
|
resp, err := a.streamingClient.Do(httpReq)
|
|
if err != nil {
|
|
if ctx.Err() != nil {
|
|
return apierror.NewTimeoutError("Azure OpenAI stream timed out")
|
|
}
|
|
return apierror.NewUpstreamError(fmt.Sprintf("opening stream to Azure OpenAI: %v", err))
|
|
}
|
|
defer resp.Body.Close() //nolint:errcheck
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return mapUpstreamStatus(resp.StatusCode)
|
|
}
|
|
|
|
// Azure SSE format is identical to OpenAI — reuse PipeSSE.
|
|
return openai.PipeSSE(bufio.NewScanner(resp.Body), w)
|
|
}
|
|
|
|
// HealthCheck verifies Azure OpenAI is reachable by listing deployments.
|
|
func (a *Adapter) HealthCheck(ctx context.Context) error {
|
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
|
defer cancel()
|
|
|
|
var url string
|
|
if a.testBaseURL != "" {
|
|
url = a.testBaseURL + fmt.Sprintf("/openai/deployments?api-version=%s", a.cfg.APIVersion)
|
|
} else {
|
|
url = fmt.Sprintf(
|
|
"https://%s.openai.azure.com/openai/deployments?api-version=%s",
|
|
a.cfg.ResourceName, a.cfg.APIVersion,
|
|
)
|
|
}
|
|
|
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
|
|
if err != nil {
|
|
return fmt.Errorf("building Azure health check request: %w", err)
|
|
}
|
|
a.setHeaders(req)
|
|
|
|
resp, err := a.regularClient.Do(req)
|
|
if err != nil {
|
|
return fmt.Errorf("Azure OpenAI health check failed: %w", err)
|
|
}
|
|
defer resp.Body.Close() //nolint:errcheck
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
return fmt.Errorf("Azure OpenAI health check returned HTTP %d", resp.StatusCode)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// chatURL builds the Azure OpenAI completions endpoint URL.
|
|
// In tests, testBaseURL overrides Azure URL construction for mock server injection.
|
|
func (a *Adapter) chatURL() string {
|
|
if a.testBaseURL != "" {
|
|
return a.testBaseURL + fmt.Sprintf(
|
|
"/openai/deployments/%s/chat/completions?api-version=%s",
|
|
a.cfg.DeploymentID, a.cfg.APIVersion,
|
|
)
|
|
}
|
|
return fmt.Sprintf(
|
|
"https://%s.openai.azure.com/openai/deployments/%s/chat/completions?api-version=%s",
|
|
a.cfg.ResourceName, a.cfg.DeploymentID, a.cfg.APIVersion,
|
|
)
|
|
}
|
|
|
|
// setHeaders attaches Azure authentication and content headers.
// Azure uses "api-key" instead of "Authorization: Bearer".
func (a *Adapter) setHeaders(req *http.Request) {
	// Azure-specific auth header; the OpenAI-style Bearer token is not used.
	req.Header.Set("api-key", a.cfg.APIKey)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")
}
|
|
|
|
func mapUpstreamStatus(status int) *apierror.APIError {
|
|
switch status {
|
|
case http.StatusUnauthorized:
|
|
return apierror.NewAuthError("Azure OpenAI rejected the API key")
|
|
case http.StatusTooManyRequests:
|
|
return apierror.NewRateLimitError("Azure OpenAI rate limit reached")
|
|
case http.StatusBadRequest:
|
|
return apierror.NewBadRequestError("Azure OpenAI rejected the request")
|
|
case http.StatusGatewayTimeout, http.StatusRequestTimeout:
|
|
return apierror.NewTimeoutError("Azure OpenAI request timed out")
|
|
default:
|
|
return apierror.NewUpstreamError(fmt.Sprintf("Azure OpenAI returned HTTP %d", status))
|
|
}
|
|
}
|