convert EmbeddingModel to string type (#629)

This gives the user the ability to pass in models for embeddings that are not
already defined in the library. Also more closely matches how the completions
API works.
This commit is contained in:
Matthew Jaffee
2024-01-15 03:33:02 -06:00
committed by GitHub
parent 682b7adb0b
commit e01a2d7231
2 changed files with 24 additions and 118 deletions

View File

@@ -47,7 +47,7 @@ func TestEmbedding(t *testing.T) {
// the AdaSearchQuery type
marshaled, err := json.Marshal(embeddingReq)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}
@@ -61,7 +61,7 @@ func TestEmbedding(t *testing.T) {
}
marshaled, err = json.Marshal(embeddingReqStrings)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}
@@ -75,28 +75,12 @@ func TestEmbedding(t *testing.T) {
}
marshaled, err = json.Marshal(embeddingReqTokens)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}
}
}
func TestEmbeddingModel(t *testing.T) {
var em openai.EmbeddingModel
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
checks.NoError(t, err, "Could not marshal embedding model")
if em != openai.AdaSimilarity {
t.Errorf("Model is not equal to AdaSimilarity")
}
err = em.UnmarshalText([]byte("some-non-existent-model"))
checks.NoError(t, err, "Could not marshal embedding model")
if em != openai.Unknown {
t.Errorf("Model is not equal to Unknown")
}
}
func TestEmbeddingEndpoint(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()