Add Image generation API (#48)

This commit is contained in:
Andrew Poydence
2023-01-03 06:15:50 -07:00
committed by GitHub
parent 1c20931ead
commit 2c55a49a34
2 changed files with 132 additions and 0 deletions

View File

@@ -129,6 +129,7 @@ func TestEdits(t *testing.T) {
t.Fatalf("edits does not properly return the correct number of choices")
}
}
func TestEmbedding(t *testing.T) {
embeddedModels := []EmbeddingModel{
AdaSimilarity,
@@ -269,6 +270,41 @@ func handleCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, string(resBytes))
}
// handleImageEndpoint Handles the images endpoint by the test server.
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte
// imagess only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var imageReq ImageRequest
if imageReq, err = getImageBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := ImageResponse{
Created: uint64(time.Now().Unix()),
}
for i := 0; i < imageReq.N; i++ {
imageData := ImageResponseDataInner{}
switch imageReq.ResponseFormat {
case CreateImageResponseFormatURL, "":
imageData.URL = "https://example.com/image.png"
case CreateImageResponseFormatB64JSON:
// This decodes to "{}" in base64.
imageData.B64JSON = "e30K"
default:
http.Error(w, "invalid response format", http.StatusBadRequest)
return
}
res.Data = append(res.Data, imageData)
}
resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}
// getCompletionBody Returns the body of the request to create a completion.
func getCompletionBody(r *http.Request) (CompletionRequest, error) {
completion := CompletionRequest{}
@@ -284,6 +320,21 @@ func getCompletionBody(r *http.Request) (CompletionRequest, error) {
return completion, nil
}
// getImageBody Returns the body of the request to create a image.
func getImageBody(r *http.Request) (ImageRequest, error) {
image := ImageRequest{}
// read the request body
reqBody, err := ioutil.ReadAll(r.Body)
if err != nil {
return ImageRequest{}, err
}
err = json.Unmarshal(reqBody, &image)
if err != nil {
return ImageRequest{}, err
}
return image, nil
}
// numTokens Returns the number of GPT-3 encoded tokens in the given text.
// This function approximates based on the rule of thumb stated by OpenAI:
// https://beta.openai.com/tokenizer
@@ -293,6 +344,25 @@ func numTokens(s string) int {
return int(float32(len(s)) / 4)
}
func TestImages(t *testing.T) {
// create the test server
var err error
ts := OpenAITestServer()
ts.Start()
defer ts.Close()
client := NewClient(testAPIToken)
ctx := context.Background()
client.BaseURL = ts.URL + "/v1"
req := ImageRequest{}
req.Prompt = "Lorem ipsum"
_, err = client.CreateImage(ctx, req)
if err != nil {
t.Fatalf("CreateImage error: %v", err)
}
}
// OpenAITestServer Creates a mocked OpenAI server which can pretend to handle requests during testing.
func OpenAITestServer() *httptest.Server {
return httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -312,6 +382,8 @@ func OpenAITestServer() *httptest.Server {
case "/v1/completions":
handleCompletionEndpoint(w, r)
return
case "/v1/images/generations":
handleImageEndpoint(w, r)
// TODO: implement the other endpoints
default:
// the endpoint doesn't exist