@@ -1,9 +1,6 @@
|
||||
package openai_test
|
||||
|
||||
import (
|
||||
. "github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@@ -12,13 +9,16 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
func TestImages(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
|
||||
_, err := client.CreateImage(context.Background(), ImageRequest{
|
||||
_, err := client.CreateImage(context.Background(), openai.ImageRequest{
|
||||
Prompt: "Lorem ipsum",
|
||||
})
|
||||
checks.NoError(t, err, "CreateImage error")
|
||||
@@ -33,20 +33,20 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
var imageReq ImageRequest
|
||||
var imageReq openai.ImageRequest
|
||||
if imageReq, err = getImageBody(r); err != nil {
|
||||
http.Error(w, "could not read request", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
res := ImageResponse{
|
||||
res := openai.ImageResponse{
|
||||
Created: time.Now().Unix(),
|
||||
}
|
||||
for i := 0; i < imageReq.N; i++ {
|
||||
imageData := ImageResponseDataInner{}
|
||||
imageData := openai.ImageResponseDataInner{}
|
||||
switch imageReq.ResponseFormat {
|
||||
case CreateImageResponseFormatURL, "":
|
||||
case openai.CreateImageResponseFormatURL, "":
|
||||
imageData.URL = "https://example.com/image.png"
|
||||
case CreateImageResponseFormatB64JSON:
|
||||
case openai.CreateImageResponseFormatB64JSON:
|
||||
// This decodes to "{}" in base64.
|
||||
imageData.B64JSON = "e30K"
|
||||
default:
|
||||
@@ -60,16 +60,16 @@ func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
// getImageBody Returns the body of the request to create a image.
|
||||
func getImageBody(r *http.Request) (ImageRequest, error) {
|
||||
image := ImageRequest{}
|
||||
func getImageBody(r *http.Request) (openai.ImageRequest, error) {
|
||||
image := openai.ImageRequest{}
|
||||
// read the request body
|
||||
reqBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
return ImageRequest{}, err
|
||||
return openai.ImageRequest{}, err
|
||||
}
|
||||
err = json.Unmarshal(reqBody, &image)
|
||||
if err != nil {
|
||||
return ImageRequest{}, err
|
||||
return openai.ImageRequest{}, err
|
||||
}
|
||||
return image, nil
|
||||
}
|
||||
@@ -98,13 +98,13 @@ func TestImageEdit(t *testing.T) {
|
||||
os.Remove("image.png")
|
||||
}()
|
||||
|
||||
_, err = client.CreateEditImage(context.Background(), ImageEditRequest{
|
||||
_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
|
||||
Image: origin,
|
||||
Mask: mask,
|
||||
Prompt: "There is a turtle in the pool",
|
||||
N: 3,
|
||||
Size: CreateImageSize1024x1024,
|
||||
ResponseFormat: CreateImageResponseFormatURL,
|
||||
Size: openai.CreateImageSize1024x1024,
|
||||
ResponseFormat: openai.CreateImageResponseFormatURL,
|
||||
})
|
||||
checks.NoError(t, err, "CreateImage error")
|
||||
}
|
||||
@@ -125,12 +125,12 @@ func TestImageEditWithoutMask(t *testing.T) {
|
||||
os.Remove("image.png")
|
||||
}()
|
||||
|
||||
_, err = client.CreateEditImage(context.Background(), ImageEditRequest{
|
||||
_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
|
||||
Image: origin,
|
||||
Prompt: "There is a turtle in the pool",
|
||||
N: 3,
|
||||
Size: CreateImageSize1024x1024,
|
||||
ResponseFormat: CreateImageResponseFormatURL,
|
||||
Size: openai.CreateImageSize1024x1024,
|
||||
ResponseFormat: openai.CreateImageResponseFormatURL,
|
||||
})
|
||||
checks.NoError(t, err, "CreateImage error")
|
||||
}
|
||||
@@ -144,9 +144,9 @@ func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
responses := ImageResponse{
|
||||
responses := openai.ImageResponse{
|
||||
Created: time.Now().Unix(),
|
||||
Data: []ImageResponseDataInner{
|
||||
Data: []openai.ImageResponseDataInner{
|
||||
{
|
||||
URL: "test-url1",
|
||||
B64JSON: "",
|
||||
@@ -182,11 +182,11 @@ func TestImageVariation(t *testing.T) {
|
||||
os.Remove("image.png")
|
||||
}()
|
||||
|
||||
_, err = client.CreateVariImage(context.Background(), ImageVariRequest{
|
||||
_, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{
|
||||
Image: origin,
|
||||
N: 3,
|
||||
Size: CreateImageSize1024x1024,
|
||||
ResponseFormat: CreateImageResponseFormatURL,
|
||||
Size: openai.CreateImageSize1024x1024,
|
||||
ResponseFormat: openai.CreateImageResponseFormatURL,
|
||||
})
|
||||
checks.NoError(t, err, "CreateImage error")
|
||||
}
|
||||
@@ -200,9 +200,9 @@ func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
|
||||
responses := ImageResponse{
|
||||
responses := openai.ImageResponse{
|
||||
Created: time.Now().Unix(),
|
||||
Data: []ImageResponseDataInner{
|
||||
Data: []openai.ImageResponseDataInner{
|
||||
{
|
||||
URL: "test-url1",
|
||||
B64JSON: "",
|
||||
|
||||
Reference in New Issue
Block a user