215 lines
5.5 KiB
Go
215 lines
5.5 KiB
Go
package service
|
|
|
|
import (
|
|
"ai-search/service/logger"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/sashabaranov/go-openai"
|
|
)
|
|
|
|
func SearchHandler(c *gin.Context) {
|
|
searchReq := &SearchReq{}
|
|
if err := c.Copy().ShouldBindJSON(searchReq); err != nil {
|
|
ErrResp[gin.H](c, nil, "error", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
cachedResult, err := getCacheResp(c, searchReq.SearchUUID)
|
|
if err == nil {
|
|
logger.Logger(c).Infof("cache key hit [%s], query: [%s]", searchReq.SearchUUID, searchReq.Query)
|
|
c.String(http.StatusOK, cachedResult)
|
|
return
|
|
}
|
|
|
|
if searchReq.Query == "" && searchReq.SearchUUID == "" {
|
|
ErrResp[gin.H](c, nil, "param is invalid", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
if searchReq.Query == "" && searchReq.SearchUUID != "" {
|
|
ErrResp[gin.H](c, nil, "content is gone", http.StatusGone)
|
|
return
|
|
}
|
|
|
|
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
|
c.Writer.Header().Set("Cache-Control", "no-cache")
|
|
c.Writer.Header().Set("Connection", "keep-alive")
|
|
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
|
|
|
cli := NewSearchClient()
|
|
searchResp, err := cli.Search(c, searchReq.Query, GetSettings().RAGSearchCount)
|
|
if err != nil {
|
|
logger.Logger(c).WithError(err).Errorf("client.Search error")
|
|
return
|
|
}
|
|
ss := &Sources{}
|
|
ss.FromSearchResp(&searchResp, searchReq.Query, searchReq.SearchUUID)
|
|
|
|
originReq := &openai.ChatCompletionRequest{
|
|
Messages: []openai.ChatCompletionMessage{
|
|
{
|
|
Role: openai.ChatMessageRoleSystem,
|
|
Content: fmt.Sprintf(RagPrompt(), getSearchContext(ss)),
|
|
},
|
|
{
|
|
Role: openai.ChatMessageRoleUser,
|
|
Content: searchReq.Query,
|
|
},
|
|
},
|
|
Stream: true,
|
|
}
|
|
|
|
apiKey, endpoint := getOpenAIConfig(c, "chat")
|
|
conf := openai.DefaultConfig(apiKey)
|
|
conf.BaseURL = endpoint
|
|
client := openai.NewClientWithConfig(conf)
|
|
request := openai.ChatCompletionRequest{
|
|
Model: openai.GPT3Dot5Turbo,
|
|
Messages: originReq.Messages,
|
|
Temperature: GetSettings().RAGParams.Temperature,
|
|
MaxTokens: GetSettings().RAGParams.MaxTokens,
|
|
Stream: true,
|
|
}
|
|
|
|
resp, err := client.CreateChatCompletionStream(
|
|
context.Background(),
|
|
request,
|
|
)
|
|
if err != nil {
|
|
logger.Logger(c).WithError(err).Errorf("client.CreateChatCompletionStream error")
|
|
}
|
|
|
|
relatedStrChan := make(chan string)
|
|
defer close(relatedStrChan)
|
|
go func() {
|
|
relatedStrChan <- getRelatedQuestionsResp(c, searchReq.Query, ss)
|
|
}()
|
|
|
|
finalResult := streamSearchItemResp(c, []string{
|
|
ss.ToString(),
|
|
"\n\n__LLM_RESPONSE__\n\n",
|
|
})
|
|
finalResult = finalResult + streamResp(c, resp)
|
|
finalResult = finalResult + streamSearchItemResp(c, []string{
|
|
"\n\n__RELATED_QUESTIONS__\n\n",
|
|
// `[{"question": "What is the formal way to say hello in Chinese?"}, {"question": "How do you say 'How are you' in Chinese?"}]`,
|
|
<-relatedStrChan,
|
|
})
|
|
|
|
GetSearchCache().Set([]byte(searchReq.SearchUUID), newCachedResult(searchReq.SearchUUID, searchReq.Query, finalResult).ToBytes(), GetSettings().RAGSearchCacheTime)
|
|
logger.Logger(c).Infof("cache key miss [%s], query: [%s], set result to cache", searchReq.SearchUUID, searchReq.Query)
|
|
}
|
|
|
|
func getCacheResp(c *gin.Context, searchUUID string) (string, error) {
|
|
ans, err := GetSearchCache().Get([]byte(searchUUID))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
if ans != nil {
|
|
cachedResult := &cachedResult{}
|
|
cachedResult.FromBytes(ans)
|
|
return cachedResult.Result, nil
|
|
}
|
|
|
|
return "", fmt.Errorf("cache not found")
|
|
}
|
|
|
|
func streamSearchItemResp(c *gin.Context, t []string) string {
|
|
result := ""
|
|
|
|
_, ok := c.Writer.(http.Flusher)
|
|
if !ok {
|
|
logger.Logger(c).Panic("server not support")
|
|
}
|
|
defer func() {
|
|
c.Writer.Flush()
|
|
}()
|
|
|
|
for _, line := range t {
|
|
_, err := c.Writer.WriteString(line)
|
|
result += line
|
|
if err != nil {
|
|
logger.Logger(c).WithError(err).Error("write string to client error")
|
|
return ""
|
|
}
|
|
}
|
|
logger.Logger(c).Info("finish stream text to client")
|
|
return result
|
|
}
|
|
|
|
func getRelatedQuestionsResp(c context.Context, query string, ss *Sources) string {
|
|
apiKey, endpoint := getOpenAIConfig(c, "chat")
|
|
conf := openai.DefaultConfig(apiKey)
|
|
conf.BaseURL = endpoint
|
|
client := openai.NewClientWithConfig(conf)
|
|
request := openai.ChatCompletionRequest{
|
|
Model: openai.GPT3Dot5Turbo,
|
|
Messages: []openai.ChatCompletionMessage{
|
|
{
|
|
Role: openai.ChatMessageRoleUser,
|
|
Content: fmt.Sprintf(MoreQuestionsPrompt(), getSearchContext(ss)) + query,
|
|
},
|
|
},
|
|
Temperature: GetSettings().RAGParams.MoreQuestionsTemperature,
|
|
MaxTokens: GetSettings().RAGParams.MoreQuestionsMaxTokens,
|
|
}
|
|
|
|
resp, err := client.CreateChatCompletion(
|
|
context.Background(),
|
|
request,
|
|
)
|
|
if err != nil {
|
|
logger.Logger(c).WithError(err).Errorf("client.CreateChatCompletion error")
|
|
}
|
|
|
|
mode := 1
|
|
|
|
cs := strings.Split(resp.Choices[0].Message.Content, ". ")
|
|
if len(cs) == 1 {
|
|
cs = strings.Split(resp.Choices[0].Message.Content, "- ")
|
|
mode = 2
|
|
}
|
|
rq := []string{}
|
|
for i, line := range cs {
|
|
if len(line) <= 2 {
|
|
continue
|
|
}
|
|
if i != len(cs)-1 && mode == 1 {
|
|
line = line[:len(line)-1]
|
|
}
|
|
rq = append(rq, line)
|
|
}
|
|
|
|
return parseRelatedQuestionsResp(rq)
|
|
}
|
|
|
|
// parseRelatedQuestionsResp converts a list of question strings into a JSON
// array of {"question": ...} objects, dropping entries that are empty or too
// short (two characters or fewer after trimming spaces).
func parseRelatedQuestionsResp(qs []string) string {
	type question struct {
		Question string `json:"question"`
	}

	items := make([]question, 0, len(qs))
	for _, s := range qs {
		if len(strings.Trim(s, " ")) <= 2 {
			continue
		}
		items = append(items, question{Question: s})
	}

	raw, _ := json.Marshal(items)
	return string(raw)
}
|
|
|
|
func getSearchContext(ss *Sources) string {
|
|
ans := ""
|
|
for i, ctx := range ss.Contexts {
|
|
ans = ans + fmt.Sprintf("[[citation:%d]] ", i+1) + ctx.Snippet + "\n\n"
|
|
}
|
|
return ans
|
|
}
|