130 lines
2.8 KiB
Go
130 lines
2.8 KiB
Go
package service
|
||
|
||
import (
	"context"
	"errors"
	"io"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/sashabaranov/go-openai"

	"ai-search/service/logger"
)
|
||
|
||
func getOpenAIConfig(c context.Context, mode string) (token, endpoint string) {
|
||
token = GetSettings().OpenAIAPIKey
|
||
endpoint = GetSettings().OpenAIEndpint
|
||
if mode == "chat" {
|
||
token = GetSettings().OpenAIChatAPIKey
|
||
endpoint = GetSettings().OpenAIChatEndpoint
|
||
}
|
||
|
||
return
|
||
}
|
||
|
||
// Queue is a bounded sliding window of duration samples: once it holds
// length samples, each Add evicts the oldest one. It is safe for concurrent
// use, which streamResp relies on (its reader goroutine calls Add while the
// writer loop calls Avg).
type Queue struct {
	mu     sync.Mutex
	data   []time.Duration
	length int // maximum number of samples retained; 0 means samples are dropped
}

// NewQueue returns a Queue that already holds length samples, each equal to
// defaultValue milliseconds. A negative length is treated as 0 instead of
// panicking in make.
func NewQueue(length int, defaultValue int) *Queue {
	if length < 0 {
		length = 0
	}
	fill := time.Duration(defaultValue) * time.Millisecond
	data := make([]time.Duration, length)
	for i := range data {
		data[i] = fill
	}
	return &Queue{
		data:   data,
		length: length,
	}
}

// Add appends value to the window, evicting the oldest sample when the
// window is full. On a zero-capacity queue the sample is discarded; without
// this guard, slicing the empty data would panic.
func (q *Queue) Add(value time.Duration) {
	q.mu.Lock()
	defer q.mu.Unlock()

	if q.length <= 0 {
		return
	}
	if len(q.data) >= q.length {
		q.data = q.data[1:]
	}
	q.data = append(q.data, value)
}

// Avg returns the mean of the non-zero samples multiplied by k[0] (default
// 1; any further elements of k are ignored), capped at 20ms. It returns 0
// when there are no non-zero samples.
func (q *Queue) Avg(k ...int) time.Duration {
	const maxAvg = 20 * time.Millisecond

	param := 1
	if len(k) > 0 {
		param = k[0]
	}

	q.mu.Lock()
	var total time.Duration
	count := 0
	for _, value := range q.data {
		// Zero samples are treated as "no measurement" and excluded.
		if value != 0 {
			total += value
			count++
		}
	}
	q.mu.Unlock()

	if count == 0 {
		return 0
	}
	ans := total * time.Duration(param) / time.Duration(count)
	if ans > maxAvg {
		return maxAvg
	}
	return ans
}
|
||
|
||
func streamResp(c *gin.Context, resp *openai.ChatCompletionStream) string {
|
||
result := ""
|
||
|
||
if resp == nil {
|
||
logger.Logger(c).Error("stream resp is nil")
|
||
return result
|
||
}
|
||
|
||
_, ok := c.Writer.(http.Flusher)
|
||
if !ok {
|
||
logger.Logger(c).Panic("server not support")
|
||
}
|
||
defer func() {
|
||
c.Writer.Flush()
|
||
}()
|
||
|
||
queue := NewQueue(GetSettings().OpenAIChatQueueLen, GetSettings().OpenAIChatNetworkDelay)
|
||
ch := make(chan rune, 1024)
|
||
|
||
go func(c *gin.Context, msgChan chan rune) {
|
||
lastTime := time.Now()
|
||
for {
|
||
line, err := resp.Recv()
|
||
if err != nil {
|
||
close(msgChan)
|
||
logger.Logger(c).WithError(err).Error("read openai completion line error")
|
||
return
|
||
}
|
||
if len(line.Choices[0].Delta.Content) == 0 {
|
||
continue
|
||
}
|
||
nowTime := time.Now()
|
||
|
||
division := strings.Count(line.Choices[0].Delta.Content, "")
|
||
for _, v := range line.Choices[0].Delta.Content {
|
||
msgChan <- v
|
||
}
|
||
|
||
during := (nowTime.Sub(lastTime) + (time.Duration(GetSettings().OpenAIChatNetworkDelay) *
|
||
time.Millisecond)) / time.Duration(division)
|
||
queue.Add(during)
|
||
lastTime = nowTime
|
||
}
|
||
}(c, ch)
|
||
|
||
for char := range ch {
|
||
str := string(char)
|
||
_, err := c.Writer.WriteString(str)
|
||
result += str
|
||
if err != nil {
|
||
logger.Logger(c).WithError(err).Error("write string to client error")
|
||
return ""
|
||
}
|
||
c.Writer.Flush()
|
||
time.Sleep(queue.Avg(len(str) * 2)) // 英文平均长度为6个字符,一个UTF8字符是3个长度,试图让一个单词等于一个汉字
|
||
}
|
||
logger.Logger(c).Info("finish stream text to client")
|
||
return result
|
||
}
|