// ai-search/service/search_handler.go
package service

import (
    "context"
    "encoding/json"
    "fmt"
    "net/http"
    "strings"

    "ai-search/service/logger"

    "github.com/gin-gonic/gin"
    "github.com/sashabaranov/go-openai"
)
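
// SearchHandler answers a search request over Server-Sent Events. A
// request whose SearchUUID is already cached is served from the cache;
// otherwise the handler runs a web search, streams a RAG answer from
// the LLM, appends related questions, and caches the combined payload
// under the SearchUUID.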
func SearchHandler(c *gin.Context) {
    searchReq := &SearchReq{}
    if err := c.ShouldBindJSON(searchReq); err != nil {
        ErrResp[gin.H](c, nil, "invalid request body", http.StatusBadRequest)
        return
    }
    // A SearchUUID that is already cached is answered directly, without
    // re-running the search or the LLM.
    cachedResult, err := getCacheResp(c, searchReq.SearchUUID)
    if err == nil {
        logger.Logger(c).Infof("cache key hit [%s], query: [%s]", searchReq.SearchUUID, searchReq.Query)
        c.String(http.StatusOK, cachedResult)
        return
    }
    if searchReq.Query == "" && searchReq.SearchUUID == "" {
        ErrResp[gin.H](c, nil, "param is invalid", http.StatusBadRequest)
        return
    }
    // A UUID without a query means the client expected a cached result
    // that has expired or was never stored.
    if searchReq.Query == "" && searchReq.SearchUUID != "" {
        ErrResp[gin.H](c, nil, "content is gone", http.StatusGone)
        return
    }
    // Everything below is streamed, so set SSE headers up front.
    c.Writer.Header().Set("Content-Type", "text/event-stream")
    c.Writer.Header().Set("Cache-Control", "no-cache")
    c.Writer.Header().Set("Connection", "keep-alive")
    c.Writer.Header().Set("Access-Control-Allow-Origin", "*")

    cli := NewSearchClient()
    searchResp, err := cli.Search(c, searchReq.Query, GetSettings().RAGSearchCount)
    if err != nil {
        logger.Logger(c).WithError(err).Error("client.Search error")
        return
    }
    ss := &Sources{}
    ss.FromSearchResp(&searchResp, searchReq.Query, searchReq.SearchUUID)

    originReq := &openai.ChatCompletionRequest{
        Messages: []openai.ChatCompletionMessage{
            {
                Role:    openai.ChatMessageRoleSystem,
                Content: fmt.Sprintf(RagPrompt(), getSearchContext(ss)),
            },
            {
                Role:    openai.ChatMessageRoleUser,
                Content: searchReq.Query,
            },
        },
        Stream: true,
    }
    apiKey, endpoint := getOpenAIConfig(c, "chat")
    conf := openai.DefaultConfig(apiKey)
    conf.BaseURL = endpoint
    client := openai.NewClientWithConfig(conf)
    request := openai.ChatCompletionRequest{
        Model:       openai.GPT3Dot5Turbo,
        Messages:    originReq.Messages,
        Temperature: GetSettings().RAGParams.Temperature,
        MaxTokens:   GetSettings().RAGParams.MaxTokens,
        Stream:      true,
    }
    resp, err := client.CreateChatCompletionStream(context.Background(), request)
    if err != nil {
        logger.Logger(c).WithError(err).Error("client.CreateChatCompletionStream error")
        return
    }
    defer resp.Close()
    // Ask for related questions concurrently while the answer streams.
    // The goroutine gets a copy of the gin context (per gin's
    // concurrency guidance) and a buffered channel so its send always
    // completes.
    relatedStrChan := make(chan string, 1)
    go func() {
        relatedStrChan <- getRelatedQuestionsResp(c.Copy(), searchReq.Query, ss)
    }()

    // The streamed payload is: sources, the LLM answer, then related
    // questions, separated by sentinel markers the client splits on.
    // The same concatenation is cached below.
    finalResult := streamSearchItemResp(c, []string{
        ss.ToString(),
        "\n\n__LLM_RESPONSE__\n\n",
    })
    finalResult += streamResp(c, resp)
    finalResult += streamSearchItemResp(c, []string{
        "\n\n__RELATED_QUESTIONS__\n\n",
        // e.g. `[{"question": "What is the formal way to say hello in Chinese?"}, {"question": "How do you say 'How are you' in Chinese?"}]`
        <-relatedStrChan,
    })
    GetSearchCache().Set([]byte(searchReq.SearchUUID), newCachedResult(searchReq.SearchUUID, searchReq.Query, finalResult).ToBytes(), GetSettings().RAGSearchCacheTime)
    logger.Logger(c).Infof("cache key miss [%s], query: [%s], set result to cache", searchReq.SearchUUID, searchReq.Query)
}
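
// getCacheResp returns the cached final result for a SearchUUID, or an
// error when the key is missing or expired.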
func getCacheResp(c *gin.Context, searchUUID string) (string, error) {
    ans, err := GetSearchCache().Get([]byte(searchUUID))
    if err != nil {
        return "", err
    }
    if ans != nil {
        cachedResult := &cachedResult{}
        cachedResult.FromBytes(ans)
        return cachedResult.Result, nil
    }
    return "", fmt.Errorf("cache not found")
}
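
// streamSearchItemResp writes each chunk to the client, flushes once at
// the end, and returns the concatenation so the caller can cache it.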
func streamSearchItemResp(c *gin.Context, t []string) string {
    if _, ok := c.Writer.(http.Flusher); !ok {
        logger.Logger(c).Panic("streaming unsupported: response writer does not implement http.Flusher")
    }
    defer c.Writer.Flush()
    result := ""
    for _, line := range t {
        if _, err := c.Writer.WriteString(line); err != nil {
            logger.Logger(c).WithError(err).Error("write string to client error")
            return ""
        }
        result += line
    }
    logger.Logger(c).Info("finish stream text to client")
    return result
}
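
// getRelatedQuestionsResp asks the LLM for follow-up questions to the
// query, given the same search context, and returns them as a JSON
// array of {"question": ...} objects.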
func getRelatedQuestionsResp(c context.Context, query string, ss *Sources) string {
    apiKey, endpoint := getOpenAIConfig(c, "chat")
    conf := openai.DefaultConfig(apiKey)
    conf.BaseURL = endpoint
    client := openai.NewClientWithConfig(conf)
    request := openai.ChatCompletionRequest{
        Model: openai.GPT3Dot5Turbo,
        Messages: []openai.ChatCompletionMessage{
            {
                Role:    openai.ChatMessageRoleUser,
                Content: fmt.Sprintf(MoreQuestionsPrompt(), getSearchContext(ss)) + query,
            },
        },
        Temperature: GetSettings().RAGParams.MoreQuestionsTemperature,
        MaxTokens:   GetSettings().RAGParams.MoreQuestionsMaxTokens,
    }
    resp, err := client.CreateChatCompletion(context.Background(), request)
    if err != nil || len(resp.Choices) == 0 {
        logger.Logger(c).WithError(err).Error("client.CreateChatCompletion error")
        // Return an empty JSON list so the stream stays well-formed.
        return "[]"
    }

    // The model answers with either a numbered list ("1. ...") or a
    // bullet list ("- ..."); try the numbered format first.
    mode := 1
    cs := strings.Split(resp.Choices[0].Message.Content, ". ")
    if len(cs) == 1 {
        cs = strings.Split(resp.Choices[0].Message.Content, "- ")
        mode = 2
    }
    rq := []string{}
    for i, line := range cs {
        if len(line) <= 2 {
            continue
        }
        // In numbered mode every segment but the last still ends with
        // the next item's number; drop that trailing character.
        if i != len(cs)-1 && mode == 1 {
            line = line[:len(line)-1]
        }
        rq = append(rq, line)
    }
    return parseRelatedQuestionsResp(rq)
}
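
// parseRelatedQuestionsResp wraps the extracted questions in the JSON
// shape the frontend expects: [{"question": "..."}].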
func parseRelatedQuestionsResp(qs []string) string {
    q := []struct {
        Question string `json:"question"`
    }{}
    for _, line := range qs {
        line = strings.TrimSpace(line)
        if len(line) <= 2 {
            continue
        }
        q = append(q, struct {
            Question string `json:"question"`
        }{line})
    }
    rawBytes, _ := json.Marshal(q)
    return string(rawBytes)
}
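
// getSearchContext flattens the retrieved snippets into the numbered
// [[citation:N]] blocks referenced by the prompts.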
func getSearchContext(ss *Sources) string {
    var sb strings.Builder
    for i, ctx := range ss.Contexts {
        fmt.Fprintf(&sb, "[[citation:%d]] %s\n\n", i+1, ctx.Snippet)
    }
    return sb.String()
}