diff --git a/api_test.go b/api_test.go
index c5401f2..30933b3 100644
--- a/api_test.go
+++ b/api_test.go
@@ -129,6 +129,7 @@ func TestEdits(t *testing.T) {
 		t.Fatalf("edits does not properly return the correct number of choices")
 	}
 }
+
 func TestEmbedding(t *testing.T) {
 	embeddedModels := []EmbeddingModel{
 		AdaSimilarity,
@@ -269,6 +270,47 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
 	fmt.Fprintln(w, string(resBytes))
 }
 
+// handleImageEndpoint Handles the images endpoint by the test server.
+// It answers POST requests with an ImageResponse holding imageReq.N entries,
+// each filled in according to the requested response format; any other
+// method is rejected with 405 Method Not Allowed.
+func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
+	var err error
+	var resBytes []byte
+
+	// images only accepts POST requests
+	if r.Method != "POST" {
+		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
+		// http.Error does not end the request: stop here so no image payload follows the error
+		return
+	}
+	var imageReq ImageRequest
+	if imageReq, err = getImageBody(r); err != nil {
+		http.Error(w, "could not read request", http.StatusInternalServerError)
+		return
+	}
+	res := ImageResponse{
+		Created: uint64(time.Now().Unix()),
+	}
+	for i := 0; i < imageReq.N; i++ {
+		imageData := ImageResponseDataInner{}
+		switch imageReq.ResponseFormat {
+		case CreateImageResponseFormatURL, "":
+			imageData.URL = "https://example.com/image.png"
+		case CreateImageResponseFormatB64JSON:
+			// "e30K" is the base64 encoding of "{}\n".
+			imageData.B64JSON = "e30K"
+		default:
+			http.Error(w, "invalid response format", http.StatusBadRequest)
+			return
+		}
+		res.Data = append(res.Data, imageData)
+	}
+	// res holds only an integer and string fields, so the Marshal error is ignored
+	resBytes, _ = json.Marshal(res)
+	fmt.Fprintln(w, string(resBytes))
+}
+
 // getCompletionBody Returns the body of the request to create a completion.
 func getCompletionBody(r *http.Request) (CompletionRequest, error) {
 	completion := CompletionRequest{}
@@ -284,6 +326,21 @@ func getCompletionBody(r *http.Request) (CompletionRequest, error) {
 	return completion, nil
 }
 
+// getImageBody Returns the body of the request to create an image.
+func getImageBody(r *http.Request) (ImageRequest, error) {
+	image := ImageRequest{}
+	// read the request body
+	reqBody, err := ioutil.ReadAll(r.Body)
+	if err != nil {
+		return ImageRequest{}, err
+	}
+	err = json.Unmarshal(reqBody, &image)
+	if err != nil {
+		return ImageRequest{}, err
+	}
+	return image, nil
+}
+
 // numTokens Returns the number of GPT-3 encoded tokens in the given text.
 // This function approximates based on the rule of thumb stated by OpenAI:
 // https://beta.openai.com/tokenizer
@@ -293,6 +344,31 @@ func numTokens(s string) int {
 	return int(float32(len(s)) / 4)
 }
 
+// TestImages Tests the images endpoint of the API using the mocked server.
+func TestImages(t *testing.T) {
+	// create the test server
+	ts := OpenAITestServer()
+	ts.Start()
+	defer ts.Close()
+
+	client := NewClient(testAPIToken)
+	ctx := context.Background()
+	client.BaseURL = ts.URL + "/v1"
+
+	req := ImageRequest{
+		Prompt: "Lorem ipsum",
+		N:      2,
+	}
+	res, err := client.CreateImage(ctx, req)
+	if err != nil {
+		t.Fatalf("CreateImage error: %v", err)
+	}
+	// the mocked server appends one data entry per requested image
+	if len(res.Data) != req.N {
+		t.Fatalf("images does not properly return the correct number of images")
+	}
+}
+
 // OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
 func OpenAITestServer() *httptest.Server {
 	return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -312,6 +388,9 @@ func OpenAITestServer() *httptest.Server {
 		case "/v1/completions":
 			handleCompletionEndpoint(w, r)
 			return
+		case "/v1/images/generations":
+			handleImageEndpoint(w, r)
+			return
 		// TODO: implement the other endpoints
 		default:
 			// the endpoint doesn't exist
diff --git a/image.go b/image.go
new file mode 100644
index 0000000..335e82f
--- /dev/null
+++ b/image.go
@@ -0,0 +1,60 @@
+package gogpt
+
+import (
+	"bytes"
+	"context"
+	"encoding/json"
+	"net/http"
+)
+
+// Image sizes defined by the OpenAI API.
+const (
+	CreateImageSize256x256   = "256x256"
+	CreateImageSize512x512   = "512x512"
+	CreateImageSize1024x1024 = "1024x1024"
+)
+
+// Image response formats defined by the OpenAI API.
+const (
+	CreateImageResponseFormatURL     = "url"
+	CreateImageResponseFormatB64JSON = "b64_json"
+)
+
+// ImageRequest represents the request structure for the image API.
+type ImageRequest struct {
+	Prompt         string `json:"prompt,omitempty"`
+	N              int    `json:"n,omitempty"`
+	Size           string `json:"size,omitempty"`
+	ResponseFormat string `json:"response_format,omitempty"`
+	User           string `json:"user,omitempty"`
+}
+
+// ImageResponse represents a response structure for image API.
+type ImageResponse struct {
+	Created uint64                   `json:"created,omitempty"`
+	Data    []ImageResponseDataInner `json:"data,omitempty"`
+}
+
+// ImageResponseDataInner represents a single entry of the response data for image API.
+type ImageResponseDataInner struct {
+	URL     string `json:"url,omitempty"`
+	B64JSON string `json:"b64_json,omitempty"`
+}
+
+// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
+func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
+	var reqBytes []byte
+	reqBytes, err = json.Marshal(request)
+	if err != nil {
+		return
+	}
+
+	urlSuffix := "/images/generations"
+	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
+	if err != nil {
+		return
+	}
+
+	err = c.sendRequest(req, &response)
+	return
+}