convert EmbeddingModel to string type (#629)

This gives the user the ability to pass in models for embeddings that are not
already defined in the library. Also more closely matches how the completions
API works.
This commit is contained in:
Matthew Jaffee
2024-01-15 03:33:02 -06:00
committed by GitHub
parent 682b7adb0b
commit e01a2d7231
2 changed files with 24 additions and 118 deletions

View File

@@ -13,108 +13,30 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch")
// EmbeddingModel enumerates the models which can be used
// to generate Embedding vectors.
type EmbeddingModel int
// String implements the fmt.Stringer interface.
func (e EmbeddingModel) String() string {
return enumToString[e]
}
// MarshalText implements the encoding.TextMarshaler interface.
func (e EmbeddingModel) MarshalText() ([]byte, error) {
return []byte(e.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface.
// On unrecognized value, it sets |e| to Unknown.
func (e *EmbeddingModel) UnmarshalText(b []byte) error {
if val, ok := stringToEnum[(string(b))]; ok {
*e = val
return nil
}
*e = Unknown
return nil
}
type EmbeddingModel string
const (
Unknown EmbeddingModel = iota
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaSimilarity
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageSimilarity
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
CurieSimilarity
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
DavinciSimilarity
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaSearchDocument
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageSearchDocument
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
CurieSearchDocument
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
CurieSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
DavinciSearchDocument
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
DavinciSearchQuery
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaCodeSearchCode
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaCodeSearchText
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageCodeSearchCode
// Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
BabbageCodeSearchText
AdaEmbeddingV2
// Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead.
AdaSimilarity EmbeddingModel = "text-similarity-ada-001"
BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001"
CurieSimilarity EmbeddingModel = "text-similarity-curie-001"
DavinciSimilarity EmbeddingModel = "text-similarity-davinci-001"
AdaSearchDocument EmbeddingModel = "text-search-ada-doc-001"
AdaSearchQuery EmbeddingModel = "text-search-ada-query-001"
BabbageSearchDocument EmbeddingModel = "text-search-babbage-doc-001"
BabbageSearchQuery EmbeddingModel = "text-search-babbage-query-001"
CurieSearchDocument EmbeddingModel = "text-search-curie-doc-001"
CurieSearchQuery EmbeddingModel = "text-search-curie-query-001"
DavinciSearchDocument EmbeddingModel = "text-search-davinci-doc-001"
DavinciSearchQuery EmbeddingModel = "text-search-davinci-query-001"
AdaCodeSearchCode EmbeddingModel = "code-search-ada-code-001"
AdaCodeSearchText EmbeddingModel = "code-search-ada-text-001"
BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001"
BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001"
AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002"
)
var enumToString = map[EmbeddingModel]string{
AdaSimilarity: "text-similarity-ada-001",
BabbageSimilarity: "text-similarity-babbage-001",
CurieSimilarity: "text-similarity-curie-001",
DavinciSimilarity: "text-similarity-davinci-001",
AdaSearchDocument: "text-search-ada-doc-001",
AdaSearchQuery: "text-search-ada-query-001",
BabbageSearchDocument: "text-search-babbage-doc-001",
BabbageSearchQuery: "text-search-babbage-query-001",
CurieSearchDocument: "text-search-curie-doc-001",
CurieSearchQuery: "text-search-curie-query-001",
DavinciSearchDocument: "text-search-davinci-doc-001",
DavinciSearchQuery: "text-search-davinci-query-001",
AdaCodeSearchCode: "code-search-ada-code-001",
AdaCodeSearchText: "code-search-ada-text-001",
BabbageCodeSearchCode: "code-search-babbage-code-001",
BabbageCodeSearchText: "code-search-babbage-text-001",
AdaEmbeddingV2: "text-embedding-ada-002",
}
var stringToEnum = map[string]EmbeddingModel{
"text-similarity-ada-001": AdaSimilarity,
"text-similarity-babbage-001": BabbageSimilarity,
"text-similarity-curie-001": CurieSimilarity,
"text-similarity-davinci-001": DavinciSimilarity,
"text-search-ada-doc-001": AdaSearchDocument,
"text-search-ada-query-001": AdaSearchQuery,
"text-search-babbage-doc-001": BabbageSearchDocument,
"text-search-babbage-query-001": BabbageSearchQuery,
"text-search-curie-doc-001": CurieSearchDocument,
"text-search-curie-query-001": CurieSearchQuery,
"text-search-davinci-doc-001": DavinciSearchDocument,
"text-search-davinci-query-001": DavinciSearchQuery,
"code-search-ada-code-001": AdaCodeSearchCode,
"code-search-ada-text-001": AdaCodeSearchText,
"code-search-babbage-code-001": BabbageCodeSearchCode,
"code-search-babbage-text-001": BabbageCodeSearchText,
"text-embedding-ada-002": AdaEmbeddingV2,
}
// Embedding is a special format of data representation that can be easily utilized by machine
// learning models and algorithms. The embedding is an information dense representation of the
// semantic meaning of a piece of text. Each embedding is a vector of floating point numbers,
@@ -306,7 +228,7 @@ func (c *Client) CreateEmbeddings(
conv EmbeddingRequestConverter,
) (res EmbeddingResponse, err error) {
baseReq := conv.Convert()
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq))
req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model), withBody(baseReq))
if err != nil {
return
}

View File

@@ -47,7 +47,7 @@ func TestEmbedding(t *testing.T) {
// the AdaSearchQuery type
marshaled, err := json.Marshal(embeddingReq)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}
@@ -61,7 +61,7 @@ func TestEmbedding(t *testing.T) {
}
marshaled, err = json.Marshal(embeddingReqStrings)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}
@@ -75,28 +75,12 @@ func TestEmbedding(t *testing.T) {
}
marshaled, err = json.Marshal(embeddingReqTokens)
checks.NoError(t, err, "Could not marshal embedding request")
if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) {
if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) {
t.Fatalf("Expected embedding request to contain model field")
}
}
}
func TestEmbeddingModel(t *testing.T) {
var em openai.EmbeddingModel
err := em.UnmarshalText([]byte("text-similarity-ada-001"))
checks.NoError(t, err, "Could not marshal embedding model")
if em != openai.AdaSimilarity {
t.Errorf("Model is not equal to AdaSimilarity")
}
err = em.UnmarshalText([]byte("some-non-existent-model"))
checks.NoError(t, err, "Could not marshal embedding model")
if em != openai.Unknown {
t.Errorf("Model is not equal to Unknown")
}
}
func TestEmbeddingEndpoint(t *testing.T) {
client, server, teardown := setupOpenAITestServer()
defer teardown()