From e01a2d7231fafec2c1cbdd176806e3be767df965 Mon Sep 17 00:00:00 2001 From: Matthew Jaffee Date: Mon, 15 Jan 2024 03:33:02 -0600 Subject: [PATCH] convert EmbeddingModel to string type (#629) This gives the user the ability to pass in models for embeddings that are not already defined in the library. Also more closely matches how the completions API works. --- embeddings.go | 120 ++++++++------------------------------------- embeddings_test.go | 22 ++------- 2 files changed, 24 insertions(+), 118 deletions(-) diff --git a/embeddings.go b/embeddings.go index 7e2aa7e..f79df9d 100644 --- a/embeddings.go +++ b/embeddings.go @@ -13,108 +13,30 @@ var ErrVectorLengthMismatch = errors.New("vector length mismatch") // EmbeddingModel enumerates the models which can be used // to generate Embedding vectors. -type EmbeddingModel int - -// String implements the fmt.Stringer interface. -func (e EmbeddingModel) String() string { - return enumToString[e] -} - -// MarshalText implements the encoding.TextMarshaler interface. -func (e EmbeddingModel) MarshalText() ([]byte, error) { - return []byte(e.String()), nil -} - -// UnmarshalText implements the encoding.TextUnmarshaler interface. -// On unrecognized value, it sets |e| to Unknown. -func (e *EmbeddingModel) UnmarshalText(b []byte) error { - if val, ok := stringToEnum[(string(b))]; ok { - *e = val - return nil - } - - *e = Unknown - - return nil -} +type EmbeddingModel string const ( - Unknown EmbeddingModel = iota - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - CurieSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - DavinciSimilarity - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - CurieSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - CurieSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - DavinciSearchDocument - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - DavinciSearchQuery - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaCodeSearchCode - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - AdaCodeSearchText - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageCodeSearchCode - // Deprecated: Will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. - BabbageCodeSearchText - AdaEmbeddingV2 + // Deprecated: The following block will be shut down on January 04, 2024. Use text-embedding-ada-002 instead. + AdaSimilarity EmbeddingModel = "text-similarity-ada-001" + BabbageSimilarity EmbeddingModel = "text-similarity-babbage-001" + CurieSimilarity EmbeddingModel = "text-similarity-curie-001" + DavinciSimilarity EmbeddingModel = "text-similarity-davinci-001" + AdaSearchDocument EmbeddingModel = "text-search-ada-doc-001" + AdaSearchQuery EmbeddingModel = "text-search-ada-query-001" + BabbageSearchDocument EmbeddingModel = "text-search-babbage-doc-001" + BabbageSearchQuery EmbeddingModel = "text-search-babbage-query-001" + CurieSearchDocument EmbeddingModel = "text-search-curie-doc-001" + CurieSearchQuery EmbeddingModel = "text-search-curie-query-001" + DavinciSearchDocument EmbeddingModel = "text-search-davinci-doc-001" + DavinciSearchQuery EmbeddingModel = "text-search-davinci-query-001" + AdaCodeSearchCode EmbeddingModel = "code-search-ada-code-001" + AdaCodeSearchText EmbeddingModel = "code-search-ada-text-001" + BabbageCodeSearchCode EmbeddingModel = "code-search-babbage-code-001" + BabbageCodeSearchText EmbeddingModel = "code-search-babbage-text-001" + + AdaEmbeddingV2 EmbeddingModel = "text-embedding-ada-002" ) -var enumToString = map[EmbeddingModel]string{ - AdaSimilarity: "text-similarity-ada-001", - BabbageSimilarity: "text-similarity-babbage-001", - CurieSimilarity: "text-similarity-curie-001", - DavinciSimilarity: "text-similarity-davinci-001", - AdaSearchDocument: "text-search-ada-doc-001", - AdaSearchQuery: "text-search-ada-query-001", - BabbageSearchDocument: "text-search-babbage-doc-001", - BabbageSearchQuery: "text-search-babbage-query-001", - CurieSearchDocument: "text-search-curie-doc-001", - CurieSearchQuery: "text-search-curie-query-001", - DavinciSearchDocument: "text-search-davinci-doc-001", - DavinciSearchQuery: "text-search-davinci-query-001", - AdaCodeSearchCode: "code-search-ada-code-001", - AdaCodeSearchText: "code-search-ada-text-001", - BabbageCodeSearchCode: "code-search-babbage-code-001", - BabbageCodeSearchText: "code-search-babbage-text-001", - AdaEmbeddingV2: "text-embedding-ada-002", -} - -var stringToEnum = map[string]EmbeddingModel{ - "text-similarity-ada-001": AdaSimilarity, - "text-similarity-babbage-001": BabbageSimilarity, - "text-similarity-curie-001": CurieSimilarity, - "text-similarity-davinci-001": DavinciSimilarity, - "text-search-ada-doc-001": AdaSearchDocument, - "text-search-ada-query-001": AdaSearchQuery, - "text-search-babbage-doc-001": BabbageSearchDocument, - "text-search-babbage-query-001": BabbageSearchQuery, - "text-search-curie-doc-001": CurieSearchDocument, - "text-search-curie-query-001": CurieSearchQuery, - "text-search-davinci-doc-001": DavinciSearchDocument, - "text-search-davinci-query-001": DavinciSearchQuery, - "code-search-ada-code-001": AdaCodeSearchCode, - "code-search-ada-text-001": AdaCodeSearchText, - "code-search-babbage-code-001": BabbageCodeSearchCode, - "code-search-babbage-text-001": BabbageCodeSearchText, - "text-embedding-ada-002": AdaEmbeddingV2, -} - // Embedding is a special format of data representation that can be easily utilized by machine // learning models and algorithms. The embedding is an information dense representation of the // semantic meaning of a piece of text. Each embedding is a vector of floating point numbers, @@ -306,7 +228,7 @@ func (c *Client) CreateEmbeddings( conv EmbeddingRequestConverter, ) (res EmbeddingResponse, err error) { baseReq := conv.Convert() - req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model.String()), withBody(baseReq)) + req, err := c.newRequest(ctx, http.MethodPost, c.fullURL("/embeddings", baseReq.Model), withBody(baseReq)) if err != nil { return } diff --git a/embeddings_test.go b/embeddings_test.go index af04d96..846d199 100644 --- a/embeddings_test.go +++ b/embeddings_test.go @@ -47,7 +47,7 @@ func TestEmbedding(t *testing.T) { // the AdaSearchQuery type marshaled, err := json.Marshal(embeddingReq) checks.NoError(t, err, "Could not marshal embedding request") - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { t.Fatalf("Expected embedding request to contain model field") } @@ -61,7 +61,7 @@ func TestEmbedding(t *testing.T) { } marshaled, err = json.Marshal(embeddingReqStrings) checks.NoError(t, err, "Could not marshal embedding request") - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { t.Fatalf("Expected embedding request to contain model field") } @@ -75,28 +75,12 @@ func TestEmbedding(t *testing.T) { } marshaled, err = json.Marshal(embeddingReqTokens) checks.NoError(t, err, "Could not marshal embedding request") - if !bytes.Contains(marshaled, []byte(`"model":"`+model.String()+`"`)) { + if !bytes.Contains(marshaled, []byte(`"model":"`+model+`"`)) { t.Fatalf("Expected embedding request to contain model field") } } } -func TestEmbeddingModel(t *testing.T) { - var em openai.EmbeddingModel - err := em.UnmarshalText([]byte("text-similarity-ada-001")) - checks.NoError(t, err, "Could not marshal embedding model") - - if em != openai.AdaSimilarity { - t.Errorf("Model is not equal to AdaSimilarity") - } - - err = em.UnmarshalText([]byte("some-non-existent-model")) - checks.NoError(t, err, "Could not marshal embedding model") - if em != openai.Unknown { - t.Errorf("Model is not equal to Unknown") - } -} - func TestEmbeddingEndpoint(t *testing.T) { client, server, teardown := setupOpenAITestServer() defer teardown()