package service

import (
	"context"
	"errors"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"ai-search/service/logger"

	"github.com/gin-gonic/gin"
	"github.com/sashabaranov/go-openai"
)

// getOpenAIConfig returns the OpenAI API token and endpoint to use for mode:
// the chat-specific settings when mode is "chat", otherwise the default ones.
// The context parameter is currently unused.
func getOpenAIConfig(c context.Context, mode string) (token, endpoint string) {
	settings := GetSettings()
	if mode == "chat" {
		return settings.OpenAIChatAPIKey, settings.OpenAIChatEndpoint
	}
	return settings.OpenAIAPIKey, settings.OpenAIEndpint
}

// Queue is a fixed-length sliding window of durations. streamResp uses it to
// track recent per-rune arrival intervals and derive the pause between runes.
//
// Add and Avg lock mu because streamResp calls them from two goroutines.
type Queue struct {
	mu     sync.Mutex
	data   []time.Duration
	length int
}

// NewQueue returns a Queue holding at most length entries, pre-filled with
// length copies of defaultValue milliseconds. A negative length is treated as
// 0, which yields a queue that ignores Add and whose Avg is always 0.
func NewQueue(length int, defaultValue int) *Queue {
	if length < 0 {
		length = 0
	}
	data := make([]time.Duration, length)
	initial := time.Duration(defaultValue) * time.Millisecond
	for i := range data {
		data[i] = initial
	}
	return &Queue{data: data, length: length}
}

// Add appends value to the window, evicting the oldest entry once the window
// already holds length entries. It does nothing for a zero-length queue, so an
// empty window is never sliced.
func (q *Queue) Add(value time.Duration) {
	q.mu.Lock()
	defer q.mu.Unlock()
	if q.length <= 0 {
		return
	}
	if len(q.data) >= q.length {
		q.data = q.data[1:]
	}
	q.data = append(q.data, value)
}

// Avg returns the mean of the non-zero entries in the window multiplied by
// k[0] (1 when k is omitted), capped at 20ms. It returns 0 when the window has
// no non-zero entries.
func (q *Queue) Avg(k ...int) time.Duration {
	const maxAvg = 20 * time.Millisecond

	multiplier := 1
	if len(k) > 0 {
		multiplier = k[0]
	}

	q.mu.Lock()
	var total time.Duration
	count := 0
	for _, value := range q.data {
		if value != 0 { // zero entries are skipped and not counted
			total += value
			count++
		}
	}
	q.mu.Unlock()

	if count == 0 {
		return 0
	}
	ans := total * time.Duration(multiplier) / time.Duration(count)
	if ans > maxAvg {
		return maxAvg
	}
	return ans
}

// streamResp relays the text of an OpenAI chat completion stream to the client
// one rune at a time. It flushes after each rune and sleeps between runes, so
// output is paced by the recently observed per-rune arrival interval (see
// Queue).
//
// It returns the complete text written to the client. It returns "" when resp
// is nil or when a write to the client fails. It calls the logger's Panic when
// the response writer does not implement http.Flusher.
//
// NOTE(review): resp is not closed here; presumably the caller closes it.
// Closing it is also what unblocks the reader goroutine if that goroutine is
// still waiting in Recv after this function returned early. Confirm against
// callers.
func streamResp(c *gin.Context, resp *openai.ChatCompletionStream) string {
	if resp == nil {
		logger.Logger(c).Error("stream resp is nil")
		return ""
	}
	if _, ok := c.Writer.(http.Flusher); !ok {
		logger.Logger(c).Panic("server not support")
	}
	defer c.Writer.Flush()

	// Read settings once so the whole stream uses a consistent snapshot.
	settings := GetSettings()
	networkDelay := time.Duration(settings.OpenAIChatNetworkDelay) * time.Millisecond
	queue := NewQueue(settings.OpenAIChatQueueLen, settings.OpenAIChatNetworkDelay)

	ch := make(chan rune, 1024)
	// done is closed when this function returns. The reader goroutine then
	// stops, instead of blocking forever on a full channel after an early
	// return.
	done := make(chan struct{})
	defer close(done)

	// The reader goroutine can outlive this handler, so it gets a copy of the
	// gin context, as gin requires for goroutines started from a handler.
	go func(c *gin.Context, msgChan chan<- rune) {
		defer close(msgChan)
		lastTime := time.Now()
		for {
			line, err := resp.Recv()
			if err != nil {
				// io.EOF is the normal end of the stream, not a failure.
				if !errors.Is(err, io.EOF) {
					logger.Logger(c).WithError(err).Error("read openai completion line error")
				}
				return
			}
			// Guard against chunks with an empty Choices slice. Indexing
			// Choices[0] unconditionally would panic in this goroutine and take
			// down the process.
			if len(line.Choices) == 0 || len(line.Choices[0].Delta.Content) == 0 {
				continue
			}
			content := line.Choices[0].Delta.Content
			nowTime := time.Now()
			for _, r := range content {
				select {
				case msgChan <- r:
				case <-done:
					return
				}
			}
			// Spread this chunk's inter-arrival time (plus the configured
			// network delay) evenly over its runes. Note that
			// strings.Count(s, "") returns one more than the rune count.
			during := (nowTime.Sub(lastTime) + networkDelay) / time.Duration(utf8.RuneCountInString(content))
			queue.Add(during)
			lastTime = nowTime
		}
	}(c.Copy(), ch)

	var result strings.Builder
	for char := range ch {
		str := string(char)
		if _, err := c.Writer.WriteString(str); err != nil {
			logger.Logger(c).WithError(err).Error("write string to client error")
			return ""
		}
		result.WriteString(str)
		c.Writer.Flush()
		// The pause scales with the rune's UTF-8 byte length. An average English
		// word is about 6 characters and a UTF-8 Chinese character is 3 bytes;
		// the intent is for one English word to take about as long as one
		// Chinese character.
		time.Sleep(queue.Avg(len(str) * 2))
	}
	logger.Logger(c).Info("finish stream text to client")
	return result.String()
}