diff --git a/client.go b/client.go index 7b1a313..9a1c895 100644 --- a/client.go +++ b/client.go @@ -38,6 +38,12 @@ func (h *httpHeader) GetRateLimitHeaders() RateLimitHeaders { return newRateLimitHeaders(h.Header()) } +type RawResponse struct { + io.ReadCloser + + httpHeader +} + // NewClient creates new OpenAI API client. func NewClient(authToken string) *Client { config := DefaultConfig(authToken) @@ -134,8 +140,8 @@ func (c *Client) sendRequest(req *http.Request, v Response) error { return decodeResponse(res.Body, v) } -func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err error) { - resp, err := c.config.HTTPClient.Do(req) +func (c *Client) sendRequestRaw(req *http.Request) (response RawResponse, err error) { + resp, err := c.config.HTTPClient.Do(req) //nolint:bodyclose // body should be closed by outer function if err != nil { return } @@ -144,7 +150,10 @@ func (c *Client) sendRequestRaw(req *http.Request) (body io.ReadCloser, err erro err = c.handleErrorResp(resp) return } - return resp.Body, nil + + response.SetHeader(resp.Header) + response.ReadCloser = resp.Body + return } func sendRequestStream[T streamable](client *Client, req *http.Request) (*streamReader[T], error) { diff --git a/files.go b/files.go index a37d45f..b40a44f 100644 --- a/files.go +++ b/files.go @@ -4,7 +4,6 @@ import ( "bytes" "context" "fmt" - "io" "net/http" "os" ) @@ -159,13 +158,12 @@ func (c *Client) GetFile(ctx context.Context, fileID string) (file File, err err return } -func (c *Client) GetFileContent(ctx context.Context, fileID string) (content io.ReadCloser, err error) { +func (c *Client) GetFileContent(ctx context.Context, fileID string) (content RawResponse, err error) { urlSuffix := fmt.Sprintf("/files/%s/content", fileID) req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix)) if err != nil { return } - content, err = c.sendRequestRaw(req) - return + return c.sendRequestRaw(req) } diff --git a/speech.go b/speech.go index 92b30b5..7e22e75 100644 --- a/speech.go +++ b/speech.go @@ -3,7 +3,6 @@ package openai import ( "context" "errors" - "io" "net/http" ) @@ -67,7 +66,7 @@ func isValidVoice(voice SpeechVoice) bool { return contains([]SpeechVoice{VoiceAlloy, VoiceEcho, VoiceFable, VoiceOnyx, VoiceNova, VoiceShimmer}, voice) } -func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response io.ReadCloser, err error) { +func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) (response RawResponse, err error) { if !isValidSpeechModel(request.Model) { err = ErrInvalidSpeechModel return @@ -84,7 +83,5 @@ func (c *Client) CreateSpeech(ctx context.Context, request CreateSpeechRequest) return } - response, err = c.sendRequestRaw(req) - - return + return c.sendRequestRaw(req) }