Add Image generation API (#48)
This commit is contained in:
72
api_test.go
72
api_test.go
@@ -129,6 +129,7 @@ func TestEdits(t *testing.T) {
|
|||||||
t.Fatalf("edits does not properly return the correct number of choices")
|
t.Fatalf("edits does not properly return the correct number of choices")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestEmbedding(t *testing.T) {
|
func TestEmbedding(t *testing.T) {
|
||||||
embeddedModels := []EmbeddingModel{
|
embeddedModels := []EmbeddingModel{
|
||||||
AdaSimilarity,
|
AdaSimilarity,
|
||||||
@@ -269,6 +270,41 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
|
|||||||
fmt.Fprintln(w, string(resBytes))
|
fmt.Fprintln(w, string(resBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleImageEndpoint Handles the images endpoint by the test server.
|
||||||
|
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var err error
|
||||||
|
var resBytes []byte
|
||||||
|
|
||||||
|
// imagess only accepts POST requests
|
||||||
|
if r.Method != "POST" {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
}
|
||||||
|
var imageReq ImageRequest
|
||||||
|
if imageReq, err = getImageBody(r); err != nil {
|
||||||
|
http.Error(w, "could not read request", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res := ImageResponse{
|
||||||
|
Created: uint64(time.Now().Unix()),
|
||||||
|
}
|
||||||
|
for i := 0; i < imageReq.N; i++ {
|
||||||
|
imageData := ImageResponseDataInner{}
|
||||||
|
switch imageReq.ResponseFormat {
|
||||||
|
case CreateImageResponseFormatURL, "":
|
||||||
|
imageData.URL = "https://example.com/image.png"
|
||||||
|
case CreateImageResponseFormatB64JSON:
|
||||||
|
// This decodes to "{}" in base64.
|
||||||
|
imageData.B64JSON = "e30K"
|
||||||
|
default:
|
||||||
|
http.Error(w, "invalid response format", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
res.Data = append(res.Data, imageData)
|
||||||
|
}
|
||||||
|
resBytes, _ = json.Marshal(res)
|
||||||
|
fmt.Fprintln(w, string(resBytes))
|
||||||
|
}
|
||||||
|
|
||||||
// getCompletionBody Returns the body of the request to create a completion.
|
// getCompletionBody Returns the body of the request to create a completion.
|
||||||
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
|
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
|
||||||
completion := CompletionRequest{}
|
completion := CompletionRequest{}
|
||||||
@@ -284,6 +320,21 @@ func getCompletionBody(r *http.Request) (CompletionRequest, error) {
|
|||||||
return completion, nil
|
return completion, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getImageBody Returns the body of the request to create a image.
|
||||||
|
func getImageBody(r *http.Request) (ImageRequest, error) {
|
||||||
|
image := ImageRequest{}
|
||||||
|
// read the request body
|
||||||
|
reqBody, err := ioutil.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRequest{}, err
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(reqBody, &image)
|
||||||
|
if err != nil {
|
||||||
|
return ImageRequest{}, err
|
||||||
|
}
|
||||||
|
return image, nil
|
||||||
|
}
|
||||||
|
|
||||||
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
|
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
|
||||||
// This function approximates based on the rule of thumb stated by OpenAI:
|
// This function approximates based on the rule of thumb stated by OpenAI:
|
||||||
// https://beta.openai.com/tokenizer
|
// https://beta.openai.com/tokenizer
|
||||||
@@ -293,6 +344,25 @@ func numTokens(s string) int {
|
|||||||
return int(float32(len(s)) / 4)
|
return int(float32(len(s)) / 4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestImages(t *testing.T) {
|
||||||
|
// create the test server
|
||||||
|
var err error
|
||||||
|
ts := OpenAITestServer()
|
||||||
|
ts.Start()
|
||||||
|
defer ts.Close()
|
||||||
|
|
||||||
|
client := NewClient(testAPIToken)
|
||||||
|
ctx := context.Background()
|
||||||
|
client.BaseURL = ts.URL + "/v1"
|
||||||
|
|
||||||
|
req := ImageRequest{}
|
||||||
|
req.Prompt = "Lorem ipsum"
|
||||||
|
_, err = client.CreateImage(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreateImage error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
|
// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
|
||||||
func OpenAITestServer() *httptest.Server {
|
func OpenAITestServer() *httptest.Server {
|
||||||
return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
@@ -312,6 +382,8 @@ func OpenAITestServer() *httptest.Server {
|
|||||||
case "/v1/completions":
|
case "/v1/completions":
|
||||||
handleCompletionEndpoint(w, r)
|
handleCompletionEndpoint(w, r)
|
||||||
return
|
return
|
||||||
|
case "/v1/images/generations":
|
||||||
|
handleImageEndpoint(w, r)
|
||||||
// TODO: implement the other endpoints
|
// TODO: implement the other endpoints
|
||||||
default:
|
default:
|
||||||
// the endpoint doesn't exist
|
// the endpoint doesn't exist
|
||||||
|
|||||||
60
image.go
Normal file
60
image.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package gogpt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Image sizes defined by the OpenAI API.
|
||||||
|
const (
|
||||||
|
CreateImageSize256x256 = "256x256"
|
||||||
|
CreateImageSize512x512 = "512x512"
|
||||||
|
CreateImageSize1024x1024 = "1024x1024"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CreateImageResponseFormatURL = "url"
|
||||||
|
CreateImageResponseFormatB64JSON = "b64_json"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ImageRequest represents the request structure for the image API.
|
||||||
|
type ImageRequest struct {
|
||||||
|
Prompt string `json:"prompt,omitempty"`
|
||||||
|
N int `json:"n,omitempty"`
|
||||||
|
Size string `json:"size,omitempty"`
|
||||||
|
ResponseFormat string `json:"response_format,omitempty"`
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageResponse represents a response structure for image API.
|
||||||
|
type ImageResponse struct {
|
||||||
|
Created uint64 `json:"created,omitempty"`
|
||||||
|
Data []ImageResponseDataInner `json:"data,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ImageResponseData represents a response data structure for image API.
|
||||||
|
type ImageResponseDataInner struct {
|
||||||
|
URL string `json:"url,omitempty"`
|
||||||
|
B64JSON string `json:"b64_json,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateImage - API call to create an image. This is the main endpoint of the DALL-E API.
|
||||||
|
func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (response ImageResponse, err error) {
|
||||||
|
var reqBytes []byte
|
||||||
|
reqBytes, err = json.Marshal(request)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
urlSuffix := "/images/generations"
|
||||||
|
req, err := http.NewRequest(http.MethodPost, c.fullURL(urlSuffix), bytes.NewBuffer(reqBytes))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req = req.WithContext(ctx)
|
||||||
|
err = c.sendRequest(req, &response)
|
||||||
|
return
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user