ai-search/service/helper.go
2024-02-03 00:49:45 +08:00

130 lines
2.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

package service
import (
"ai-search/service/logger"
"context"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/sashabaranov/go-openai"
)
// getOpenAIConfig returns the API key and endpoint to use for an OpenAI call.
// When mode is "chat" the chat-specific key and endpoint are returned; every
// other mode gets the general key and endpoint. The context argument is not
// used.
func getOpenAIConfig(c context.Context, mode string) (token, endpoint string) {
	settings := GetSettings()
	switch mode {
	case "chat":
		return settings.OpenAIChatAPIKey, settings.OpenAIChatEndpoint
	default:
		// NOTE(review): "OpenAIEndpint" is spelled as the settings field is
		// declared elsewhere; renaming it requires changing that type too.
		return settings.OpenAIAPIKey, settings.OpenAIEndpint
	}
}
// Queue is a fixed-capacity sliding window of recent durations. Once it holds
// length values, each Add evicts the oldest one.
//
// Queue does no locking and is not safe for concurrent use; callers that share
// one between goroutines must synchronize access themselves.
type Queue struct {
	data   []time.Duration
	length int
}

// NewQueue returns a Queue with capacity length, pre-filled with length copies
// of defaultValue milliseconds so that Avg has a value before any Add.
// A negative length is treated as 0, which yields a queue that stores nothing.
func NewQueue(length int, defaultValue int) *Queue {
	if length < 0 {
		// make would panic on a negative size.
		length = 0
	}
	seed := time.Duration(defaultValue) * time.Millisecond
	data := make([]time.Duration, length)
	for i := range data {
		data[i] = seed
	}
	return &Queue{data: data, length: length}
}

// Add appends value to the window. When the window is full, the oldest entry
// is evicted first. A zero-capacity queue ignores the value.
func (q *Queue) Add(value time.Duration) {
	if q.length <= 0 {
		// Nothing can be stored; slicing off the "oldest" element of an empty
		// slice would panic.
		return
	}
	if len(q.data) < q.length {
		q.data = append(q.data, value)
		return
	}
	// Full: shift left in place so the backing array is reused rather than
	// periodically reallocated by append.
	copy(q.data, q.data[1:])
	q.data[len(q.data)-1] = value
}

// Avg returns the mean of the non-zero durations in the window, multiplied by
// the optional factor k[0] (default 1; additional values are ignored), and
// capped at 20ms. It returns 0 when the window has no non-zero entries.
func (q *Queue) Avg(k ...int) time.Duration {
	const maxAvg = 20 * time.Millisecond

	factor := 1
	if len(k) > 0 {
		factor = k[0]
	}

	var total time.Duration
	count := 0
	for _, d := range q.data {
		if d == 0 {
			// Zero entries are excluded from both the sum and the count.
			continue
		}
		total += d
		count++
	}
	if count == 0 {
		return 0
	}

	// Multiply before dividing to limit precision loss from integer division.
	avg := total * time.Duration(factor) / time.Duration(count)
	if avg > maxAvg {
		return maxAvg
	}
	return avg
}
// streamResp relays a streaming chat completion to the client one rune at a
// time and returns the full text that was written.
//
// A background goroutine reads chunks from resp while the calling goroutine
// writes them out. Between runes it sleeps for a moving average (Queue.Avg) of
// how long recent upstream runes took to arrive, so the output is paced
// roughly like the upstream.
//
// If resp is nil it logs and returns "". If a write to the client fails it logs
// and returns "". This function does not close resp.
func streamResp(c *gin.Context, resp *openai.ChatCompletionStream) string {
	if resp == nil {
		logger.Logger(c).Error("stream resp is nil")
		return ""
	}
	if _, ok := c.Writer.(http.Flusher); !ok {
		logger.Logger(c).Panic("server not support")
	}
	defer c.Writer.Flush()

	networkDelay := GetSettings().OpenAIChatNetworkDelay
	queue := NewQueue(GetSettings().OpenAIChatQueueLen, networkDelay)
	extraDelay := time.Duration(networkDelay) * time.Millisecond

	// chunk is one upstream delta plus the estimated time each of its runes
	// took to arrive. Passing the estimate along with the text keeps the
	// (unsynchronized) queue owned by this goroutine only.
	type chunk struct {
		text    string
		perRune time.Duration
	}
	chunks := make(chan chunk, 1024)

	// done is closed when streamResp returns so the reader goroutine can stop
	// instead of blocking forever on a send nobody will receive.
	done := make(chan struct{})
	defer close(done)

	// The original *gin.Context must not be used from another goroutine once
	// the handler may have returned; gin provides Copy for that purpose.
	bg := c.Copy()
	go func() {
		defer close(chunks)
		lastTime := time.Now()
		for {
			line, err := resp.Recv()
			if err != nil {
				// NOTE(review): go-openai reports a normal end of stream as io.EOF,
				// so this is also logged on success; consider distinguishing it.
				logger.Logger(bg).WithError(err).Error("read openai completion line error")
				return
			}
			// A chunk with no choices would make Choices[0] panic, and a panic in
			// this goroutine cannot be recovered by the handler and would crash
			// the process.
			if len(line.Choices) == 0 || line.Choices[0].Delta.Content == "" {
				continue
			}
			content := line.Choices[0].Delta.Content
			now := time.Now()
			// content is non-empty, so the rune count is at least 1.
			runeCount := len([]rune(content))
			perRune := (now.Sub(lastTime) + extraDelay) / time.Duration(runeCount)
			lastTime = now
			select {
			case chunks <- chunk{text: content, perRune: perRune}:
			case <-done:
				return
			}
		}
	}()

	var out strings.Builder
	for ck := range chunks {
		queue.Add(ck.perRune)
		for _, r := range ck.text {
			s := string(r)
			if _, err := c.Writer.WriteString(s); err != nil {
				logger.Logger(c).WithError(err).Error("write string to client error")
				return ""
			}
			out.WriteString(s)
			c.Writer.Flush()
			// An average English word is about 6 characters, and a UTF-8 Chinese
			// character is 3 bytes long. Scaling the delay by byte length tries to
			// make one English word take about as long as one Chinese character.
			time.Sleep(queue.Avg(len(s) * 2))
		}
	}
	logger.Logger(c).Info("finish stream text to client")
	return out.String()
}