Files
go-openai/image_api_test.go
2025-06-15 12:58:45 +00:00

215 lines
5.7 KiB
Go

package openai_test
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"testing"
"time"
"git.vaala.cloud/VaalaCat/go-openai"
"git.vaala.cloud/VaalaCat/go-openai/internal/test/checks"
)
func TestImages(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/images/generations", handleImageEndpoint)
_, err := client.CreateImage(context.Background(), openai.ImageRequest{
Prompt: "Lorem ipsum",
Model: openai.CreateImageModelDallE3,
N: 1,
Quality: openai.CreateImageQualityHD,
Size: openai.CreateImageSize1024x1024,
Style: openai.CreateImageStyleVivid,
ResponseFormat: openai.CreateImageResponseFormatURL,
User: "user",
})
checks.NoError(t, err, "CreateImage error")
}
// handleImageEndpoint Handles the images endpoint by the test server.
func handleImageEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
var resBytes []byte
// images only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
var imageReq openai.ImageRequest
if imageReq, err = getImageBody(r); err != nil {
http.Error(w, "could not read request", http.StatusInternalServerError)
return
}
res := openai.ImageResponse{
Created: time.Now().Unix(),
}
for i := 0; i < imageReq.N; i++ {
imageData := openai.ImageResponseDataInner{}
switch imageReq.ResponseFormat {
case openai.CreateImageResponseFormatURL, "":
imageData.URL = "https://example.com/image.png"
case openai.CreateImageResponseFormatB64JSON:
// This decodes to "{}" in base64.
imageData.B64JSON = "e30K"
default:
http.Error(w, "invalid response format", http.StatusBadRequest)
return
}
res.Data = append(res.Data, imageData)
}
resBytes, _ = json.Marshal(res)
fmt.Fprintln(w, string(resBytes))
}
// getImageBody Returns the body of the request to create a image.
func getImageBody(r *http.Request) (openai.ImageRequest, error) {
image := openai.ImageRequest{}
// read the request body
reqBody, err := io.ReadAll(r.Body)
if err != nil {
return openai.ImageRequest{}, err
}
err = json.Unmarshal(reqBody, &image)
if err != nil {
return openai.ImageRequest{}, err
}
return image, nil
}
func TestImageEdit(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
if err != nil {
t.Fatalf("open origin file error: %v", err)
}
defer origin.Close()
mask, err := os.Create(filepath.Join(t.TempDir(), "mask.png"))
if err != nil {
t.Fatalf("open mask file error: %v", err)
}
defer mask.Close()
_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
Image: origin,
Mask: mask,
Prompt: "There is a turtle in the pool",
N: 3,
Size: openai.CreateImageSize1024x1024,
ResponseFormat: openai.CreateImageResponseFormatURL,
})
checks.NoError(t, err, "CreateImage error")
}
func TestImageEditWithoutMask(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
if err != nil {
t.Fatalf("open origin file error: %v", err)
}
defer origin.Close()
_, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{
Image: origin,
Prompt: "There is a turtle in the pool",
N: 3,
Size: openai.CreateImageSize1024x1024,
ResponseFormat: openai.CreateImageResponseFormatURL,
})
checks.NoError(t, err, "CreateImage error")
}
// handleEditImageEndpoint Handles the images endpoint by the test server.
func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
// images only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
responses := openai.ImageResponse{
Created: time.Now().Unix(),
Data: []openai.ImageResponseDataInner{
{
URL: "test-url1",
B64JSON: "",
},
{
URL: "test-url2",
B64JSON: "",
},
{
URL: "test-url3",
B64JSON: "",
},
},
}
resBytes, _ = json.Marshal(responses)
fmt.Fprintln(w, string(resBytes))
}
func TestImageVariation(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()
server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
origin, err := os.Create(filepath.Join(t.TempDir(), "image.png"))
if err != nil {
t.Fatalf("open origin file error: %v", err)
}
defer origin.Close()
_, err = client.CreateVariImage(context.Background(), openai.ImageVariRequest{
Image: origin,
N: 3,
Size: openai.CreateImageSize1024x1024,
ResponseFormat: openai.CreateImageResponseFormatURL,
})
checks.NoError(t, err, "CreateImage error")
}
// handleVariateImageEndpoint Handles the images endpoint by the test server.
func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
var resBytes []byte
// images only accepts POST requests
if r.Method != "POST" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
}
responses := openai.ImageResponse{
Created: time.Now().Unix(),
Data: []openai.ImageResponseDataInner{
{
URL: "test-url1",
B64JSON: "",
},
{
URL: "test-url2",
B64JSON: "",
},
{
URL: "test-url3",
B64JSON: "",
},
},
}
resBytes, _ = json.Marshal(responses)
fmt.Fprintln(w, string(resBytes))
}