package service

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/sashabaranov/go-openai"

	"ai-search/service/logger"
)

// SearchHandler serves a RAG search request as a server-sent-event style
// stream. It first tries the result cache keyed by SearchUUID; on a miss it
// runs the web search, streams the sources, the LLM answer and the related
// questions to the client, and finally stores the concatenated result back
// into the cache.
func SearchHandler(c *gin.Context) {
	searchReq := &SearchReq{}
	if err := c.Copy().ShouldBindJSON(searchReq); err != nil {
		ErrResp[gin.H](c, nil, "error", http.StatusBadRequest)
		return
	}

	// Fast path: a previously computed answer for this UUID.
	cachedResult, err := getCacheResp(c, searchReq.SearchUUID)
	if err == nil {
		logger.Logger(c).Infof("cache key hit [%s], query: [%s]", searchReq.SearchUUID, searchReq.Query)
		c.String(http.StatusOK, cachedResult)
		return
	}

	if searchReq.Query == "" && searchReq.SearchUUID == "" {
		ErrResp[gin.H](c, nil, "param is invalid", http.StatusBadRequest)
		return
	}
	if searchReq.Query == "" && searchReq.SearchUUID != "" {
		// A UUID with no query means the client expected a cached entry
		// that has since expired.
		ErrResp[gin.H](c, nil, "content is gone", http.StatusGone)
		return
	}

	// Headers for an event-stream response that must not be cached.
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("Access-Control-Allow-Origin", "*")

	cli := NewSearchClient()
	searchResp, err := cli.Search(c, searchReq.Query, GetSettings().RAGSearchCount)
	if err != nil {
		logger.Logger(c).WithError(err).Errorf("client.Search error")
		return
	}
	ss := &Sources{}
	ss.FromSearchResp(&searchResp, searchReq.Query, searchReq.SearchUUID)

	apiKey, endpoint := getOpenAIConfig(c, "chat")
	conf := openai.DefaultConfig(apiKey)
	conf.BaseURL = endpoint
	client := openai.NewClientWithConfig(conf)

	request := openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role: openai.ChatMessageRoleSystem,
				// The system prompt embeds the numbered search context so the
				// model can cite sources.
				Content: fmt.Sprintf(RagPrompt(), getSearchContext(ss)),
			},
			{
				Role:    openai.ChatMessageRoleUser,
				Content: searchReq.Query,
			},
		},
		Temperature: GetSettings().RAGParams.Temperature,
		MaxTokens:   GetSettings().RAGParams.MaxTokens,
		Stream:      true,
	}
	resp, err := client.CreateChatCompletionStream(
		context.Background(),
		request,
	)
	if err != nil {
		// Bug fix: previously only logged and fell through, which would hand
		// a nil stream to streamResp and panic. Abort the request instead.
		logger.Logger(c).WithError(err).Errorf("client.CreateChatCompletionStream error")
		return
	}
	defer resp.Close()

	// Compute the related questions concurrently with the answer stream.
	// Buffered so the goroutine can always complete its single send and
	// exit, even if this handler were to return without receiving.
	relatedStrChan := make(chan string, 1)
	go func() {
		relatedStrChan <- getRelatedQuestionsResp(c, searchReq.Query, ss)
	}()

	finalResult := streamSearchItemResp(c, []string{
		ss.ToString(),
		"\n\n__LLM_RESPONSE__\n\n",
	})
	finalResult += streamResp(c, resp)
	finalResult += streamSearchItemResp(c, []string{
		"\n\n__RELATED_QUESTIONS__\n\n",
		<-relatedStrChan,
	})

	GetSearchCache().Set(
		[]byte(searchReq.SearchUUID),
		newCachedResult(searchReq.SearchUUID, searchReq.Query, finalResult).ToBytes(),
		GetSettings().RAGSearchCacheTime,
	)
	logger.Logger(c).Infof("cache key miss [%s], query: [%s], set result to cache", searchReq.SearchUUID, searchReq.Query)
}

// getCacheResp looks up the cached result for searchUUID and returns its
// rendered Result string. It returns an error when the cache has no usable
// entry for the key.
func getCacheResp(c *gin.Context, searchUUID string) (string, error) {
	ans, err := GetSearchCache().Get([]byte(searchUUID))
	if err != nil {
		return "", err
	}
	if ans != nil {
		cachedResult := &cachedResult{}
		cachedResult.FromBytes(ans)
		return cachedResult.Result, nil
	}
	return "", fmt.Errorf("cache not found")
}

// streamSearchItemResp writes each string in t to the client and flushes once
// at the end. It returns the concatenation of everything written, or "" if
// any write failed. Panics if the response writer does not support flushing.
func streamSearchItemResp(c *gin.Context, t []string) string {
	result := ""
	if _, ok := c.Writer.(http.Flusher); !ok {
		logger.Logger(c).Panic("server not support")
	}
	defer func() {
		c.Writer.Flush()
	}()
	for _, line := range t {
		_, err := c.Writer.WriteString(line)
		result += line
		if err != nil {
			logger.Logger(c).WithError(err).Error("write string to client error")
			// NOTE: a partial write discards everything streamed so far from
			// the returned result, so the failed chunk is not cached.
			return ""
		}
	}
	logger.Logger(c).Info("finish stream text to client")
	return result
}

// getRelatedQuestionsResp asks the chat model for follow-up questions to the
// query, given the search context in ss, and returns them as a JSON array of
// {"question": ...} objects. On any failure it returns an empty JSON array.
func getRelatedQuestionsResp(c context.Context, query string, ss *Sources) string {
	apiKey, endpoint := getOpenAIConfig(c, "chat")
	conf := openai.DefaultConfig(apiKey)
	conf.BaseURL = endpoint
	client := openai.NewClientWithConfig(conf)
	request := openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: fmt.Sprintf(MoreQuestionsPrompt(), getSearchContext(ss)) + query,
			},
		},
		Temperature: GetSettings().RAGParams.MoreQuestionsTemperature,
		MaxTokens:   GetSettings().RAGParams.MoreQuestionsMaxTokens,
	}
	resp, err := client.CreateChatCompletion(
		context.Background(),
		request,
	)
	if err != nil {
		// Bug fix: previously indexed resp.Choices[0] after a failed call,
		// which panics on the zero-value response. Degrade to "no related
		// questions" instead.
		logger.Logger(c).WithError(err).Errorf("client.CreateChatCompletion error")
		return parseRelatedQuestionsResp(nil)
	}
	if len(resp.Choices) == 0 {
		return parseRelatedQuestionsResp(nil)
	}

	// The model answers either as a numbered list ("1. ...") or a bulleted
	// list ("- ..."); detect which by whether the first split produced
	// anything.
	mode := 1
	cs := strings.Split(resp.Choices[0].Message.Content, ". ")
	if len(cs) == 1 {
		cs = strings.Split(resp.Choices[0].Message.Content, "- ")
		mode = 2
	}
	rq := []string{}
	for i, line := range cs {
		if len(line) <= 2 {
			continue
		}
		if i != len(cs)-1 && mode == 1 {
			// In numbered mode every fragment except the last still carries
			// the next item's leading index digit; strip it.
			line = line[:len(line)-1]
		}
		rq = append(rq, line)
	}
	return parseRelatedQuestionsResp(rq)
}

// parseRelatedQuestionsResp converts a list of question strings into a JSON
// array of {"question": ...} objects, skipping blank or near-empty entries.
func parseRelatedQuestionsResp(qs []string) string {
	q := []struct {
		Question string `json:"question"`
	}{}
	for _, line := range qs {
		if len(strings.Trim(line, " ")) <= 2 {
			continue
		}
		q = append(q, struct {
			Question string `json:"question"`
		}{line})
	}
	rawBytes, _ := json.Marshal(q)
	return string(rawBytes)
}

// getSearchContext renders the source snippets as a numbered citation block
// ("[[citation:N]] <snippet>") for interpolation into the prompts.
func getSearchContext(ss *Sources) string {
	var b strings.Builder
	for i, ctx := range ss.Contexts {
		fmt.Fprintf(&b, "[[citation:%d]] %s\n\n", i+1, ctx.Snippet)
	}
	return b.String()
}