Chore Support base64 embedding format (#485)
* chore: support base64 embedding format * fix: add sizeOfFloat32 * chore: refactor base64 decoding * chore: add tests * fix linting * fix test * fix return error * fix: use smaller slice for tests * fix [skip ci] * chore: refactor test to consider CreateEmbeddings response * trigger build * chore: remove named returns * chore: refactor code to simplify the understanding * chore: tests have been refactored to match the encoding format passed by request * chore: fix tests * fix * fix
This commit is contained in:
@@ -1,15 +1,16 @@
|
||||
package openai_test
|
||||
|
||||
import (
|
||||
. "github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
. "github.com/sashabaranov/go-openai"
|
||||
"github.com/sashabaranov/go-openai/internal/test/checks"
|
||||
)
|
||||
|
||||
func TestEmbedding(t *testing.T) {
|
||||
@@ -97,22 +98,138 @@ func TestEmbeddingModel(t *testing.T) {
|
||||
func TestEmbeddingEndpoint(t *testing.T) {
|
||||
client, server, teardown := setupOpenAITestServer()
|
||||
defer teardown()
|
||||
|
||||
sampleEmbeddings := []Embedding{
|
||||
{Embedding: []float32{1.23, 4.56, 7.89}},
|
||||
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
|
||||
}
|
||||
|
||||
sampleBase64Embeddings := []Base64Embedding{
|
||||
{Embedding: "pHCdP4XrkUDhevxA"},
|
||||
{Embedding: "/1jku0G/rLvA/EI8"},
|
||||
}
|
||||
|
||||
server.RegisterHandler(
|
||||
"/v1/embeddings",
|
||||
func(w http.ResponseWriter, r *http.Request) {
|
||||
resBytes, _ := json.Marshal(EmbeddingResponse{})
|
||||
var req struct {
|
||||
EncodingFormat EmbeddingEncodingFormat `json:"encoding_format"`
|
||||
User string `json:"user"`
|
||||
}
|
||||
_ = json.NewDecoder(r.Body).Decode(&req)
|
||||
|
||||
var resBytes []byte
|
||||
switch {
|
||||
case req.User == "invalid":
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
case req.EncodingFormat == EmbeddingEncodingFormatBase64:
|
||||
resBytes, _ = json.Marshal(EmbeddingResponseBase64{Data: sampleBase64Embeddings})
|
||||
default:
|
||||
resBytes, _ = json.Marshal(EmbeddingResponse{Data: sampleEmbeddings})
|
||||
}
|
||||
fmt.Fprintln(w, string(resBytes))
|
||||
},
|
||||
)
|
||||
// test create embeddings with strings (simple embedding request)
|
||||
_, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{})
|
||||
res, err := client.CreateEmbeddings(context.Background(), EmbeddingRequest{})
|
||||
checks.NoError(t, err, "CreateEmbeddings error")
|
||||
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||
}
|
||||
|
||||
// test create embeddings with strings (simple embedding request)
|
||||
res, err = client.CreateEmbeddings(
|
||||
context.Background(),
|
||||
EmbeddingRequest{
|
||||
EncodingFormat: EmbeddingEncodingFormatBase64,
|
||||
},
|
||||
)
|
||||
checks.NoError(t, err, "CreateEmbeddings error")
|
||||
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||
}
|
||||
|
||||
// test create embeddings with strings
|
||||
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{})
|
||||
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestStrings{})
|
||||
checks.NoError(t, err, "CreateEmbeddings strings error")
|
||||
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||
}
|
||||
|
||||
// test create embeddings with tokens
|
||||
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{})
|
||||
res, err = client.CreateEmbeddings(context.Background(), EmbeddingRequestTokens{})
|
||||
checks.NoError(t, err, "CreateEmbeddings tokens error")
|
||||
if !reflect.DeepEqual(res.Data, sampleEmbeddings) {
|
||||
t.Errorf("Expected %#v embeddings, got %#v", sampleEmbeddings, res.Data)
|
||||
}
|
||||
|
||||
// test failed sendRequest
|
||||
_, err = client.CreateEmbeddings(context.Background(), EmbeddingRequest{
|
||||
User: "invalid",
|
||||
EncodingFormat: EmbeddingEncodingFormatBase64,
|
||||
})
|
||||
checks.HasError(t, err, "CreateEmbeddings error")
|
||||
}
|
||||
|
||||
func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
|
||||
type fields struct {
|
||||
Object string
|
||||
Data []Base64Embedding
|
||||
Model EmbeddingModel
|
||||
Usage Usage
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
want EmbeddingResponse
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "test embedding response base64 to embedding response",
|
||||
fields: fields{
|
||||
Data: []Base64Embedding{
|
||||
{Embedding: "pHCdP4XrkUDhevxA"},
|
||||
{Embedding: "/1jku0G/rLvA/EI8"},
|
||||
},
|
||||
},
|
||||
want: EmbeddingResponse{
|
||||
Data: []Embedding{
|
||||
{Embedding: []float32{1.23, 4.56, 7.89}},
|
||||
{Embedding: []float32{-0.006968617, -0.0052718227, 0.011901081}},
|
||||
},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "Invalid embedding",
|
||||
fields: fields{
|
||||
Data: []Base64Embedding{
|
||||
{
|
||||
Embedding: "----",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: EmbeddingResponse{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &EmbeddingResponseBase64{
|
||||
Object: tt.fields.Object,
|
||||
Data: tt.fields.Data,
|
||||
Model: tt.fields.Model,
|
||||
Usage: tt.fields.Usage,
|
||||
}
|
||||
got, err := r.ToEmbeddingResponse()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("EmbeddingResponseBase64.ToEmbeddingResponse() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user