Add DotProduct Method and README Example for Embedding Similarity Search (#492)
* Add DotProduct Method and README Example for Embedding Similarity Search
  - Implement a DotProduct() method for the Embedding struct to calculate the dot product between two embeddings.
  - Add a custom error type for vector length mismatch.
  - Update README.md with a complete example demonstrating how to perform an embedding similarity search for user queries.
  - Add unit tests to validate the new DotProduct() method and error handling.
* Update README to focus on Embedding Semantic Similarity
This commit is contained in:
56
README.md
56
README.md
@@ -483,6 +483,62 @@ func main() {
|
|||||||
```
|
```
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
<details>
|
||||||
|
<summary>Embedding Semantic Similarity</summary>
|
||||||
|
|
||||||
|
```go
|
||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"log"
|
||||||
|
openai "github.com/sashabaranov/go-openai"
|
||||||
|
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
client := openai.NewClient("your-token")
|
||||||
|
|
||||||
|
// Create an EmbeddingRequest for the user query
|
||||||
|
queryReq := openai.EmbeddingRequest{
|
||||||
|
Input: []string{"How many chucks would a woodchuck chuck"},
|
||||||
|
Model: openai.AdaEmbeddingv2,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an embedding for the user query
|
||||||
|
queryResponse, err := client.CreateEmbeddings(context.Background(), queryReq)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Error creating query embedding:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an EmbeddingRequest for the target text
|
||||||
|
targetReq := openai.EmbeddingRequest{
|
||||||
|
Input: []string{"How many chucks would a woodchuck chuck if the woodchuck could chuck wood"},
|
||||||
|
Model: openai.AdaEmbeddingv2,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create an embedding for the target text
|
||||||
|
targetResponse, err := client.CreateEmbeddings(context.Background(), targetReq)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Error creating target embedding:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now that we have the embeddings for the user query and the target text, we
|
||||||
|
// can calculate their similarity.
|
||||||
|
queryEmbedding := queryResponse.Data[0]
|
||||||
|
targetEmbedding := targetResponse.Data[0]
|
||||||
|
|
||||||
|
similarity, err := queryEmbedding.DotProduct(&targetEmbedding)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Error calculating dot product:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("The similarity score between the query and the target is %f", similarity)
|
||||||
|
}
|
||||||
|
|
||||||
|
```
|
||||||
|
</details>
|
||||||
|
|
||||||
<details>
|
<details>
|
||||||
<summary>Azure OpenAI Embeddings</summary>
|
<summary>Azure OpenAI Embeddings</summary>
|
||||||
|
|
||||||
|
|||||||
@@ -4,10 +4,13 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"math"
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrVectorLengthMismatch is returned by Embedding.DotProduct when the two
// embedding vectors have different lengths.
var ErrVectorLengthMismatch = errors.New("vector length mismatch")
|
||||||
|
|
||||||
// EmbeddingModel enumerates the models which can be used
|
// EmbeddingModel enumerates the models which can be used
|
||||||
// to generate Embedding vectors.
|
// to generate Embedding vectors.
|
||||||
type EmbeddingModel int
|
type EmbeddingModel int
|
||||||
@@ -124,6 +127,23 @@ type Embedding struct {
|
|||||||
Index int `json:"index"`
|
Index int `json:"index"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DotProduct calculates the dot product of the embedding vector with another
|
||||||
|
// embedding vector. Both vectors must have the same length; otherwise, an
|
||||||
|
// ErrVectorLengthMismatch is returned. The method returns the calculated dot
|
||||||
|
// product as a float32 value.
|
||||||
|
func (e *Embedding) DotProduct(other *Embedding) (float32, error) {
|
||||||
|
if len(e.Embedding) != len(other.Embedding) {
|
||||||
|
return 0, ErrVectorLengthMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
var dotProduct float32
|
||||||
|
for i := range e.Embedding {
|
||||||
|
dotProduct += e.Embedding[i] * other.Embedding[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
return dotProduct, nil
|
||||||
|
}
|
||||||
|
|
||||||
// EmbeddingResponse is the response from a Create embeddings request.
|
// EmbeddingResponse is the response from a Create embeddings request.
|
||||||
type EmbeddingResponse struct {
|
type EmbeddingResponse struct {
|
||||||
Object string `json:"object"`
|
Object string `json:"object"`
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"math"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -233,3 +235,39 @@ func TestEmbeddingResponseBase64_ToEmbeddingResponse(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDotProduct(t *testing.T) {
|
||||||
|
v1 := &Embedding{Embedding: []float32{1, 2, 3}}
|
||||||
|
v2 := &Embedding{Embedding: []float32{2, 4, 6}}
|
||||||
|
expected := float32(28.0)
|
||||||
|
|
||||||
|
result, err := v1.DotProduct(v2)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if math.Abs(float64(result-expected)) > 1e-12 {
|
||||||
|
t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
v1 = &Embedding{Embedding: []float32{1, 0, 0}}
|
||||||
|
v2 = &Embedding{Embedding: []float32{0, 1, 0}}
|
||||||
|
expected = float32(0.0)
|
||||||
|
|
||||||
|
result, err = v1.DotProduct(v2)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if math.Abs(float64(result-expected)) > 1e-12 {
|
||||||
|
t.Errorf("Unexpected result. Expected: %v, but got %v", expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test for VectorLengthMismatchError
|
||||||
|
v1 = &Embedding{Embedding: []float32{1, 0, 0}}
|
||||||
|
v2 = &Embedding{Embedding: []float32{0, 1}}
|
||||||
|
_, err = v1.DotProduct(v2)
|
||||||
|
if !errors.Is(err, ErrVectorLengthMismatch) {
|
||||||
|
t.Errorf("Expected Vector Length Mismatch Error, but got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user