first commit

Author: Vaala Cat
Date: 2024-02-02 23:21:43 +08:00
Commit: 0b86be3426

52 changed files with 4390 additions and 0 deletions

service/cache.go (new file, 18 lines)

@@ -0,0 +1,18 @@
package service

import (
	"github.com/coocood/freecache"
)

var (
	searchCache *freecache.Cache
)

// InitCache allocates the in-memory search cache. CacheSize is configured
// in MB (default 100), so convert it to bytes here.
func InitCache() {
	cacheSize := GetSettings().CacheSize * 1024 * 1024
	searchCache = freecache.NewCache(cacheSize)
}

func GetSearchCache() *freecache.Cache {
	return searchCache
}
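
search_handler.go later uses this cache to memoize complete answer streams keyed by search_uuid. For reference, a hypothetical round trip with freecache's byte-slice API (TTL in seconds); the example function and key are illustrative, not part of the commit:

// Hypothetical illustration of the freecache round trip; assumes InitCache ran.
package service

import "log"

func cacheExample() {
	key := []byte("example-search-uuid")
	if err := GetSearchCache().Set(key, []byte("cached stream payload"), 1200); err != nil {
		// freecache rejects entries larger than 1/1024 of the cache size.
		log.Println("cache set failed:", err)
	}
	if val, err := GetSearchCache().Get(key); err == nil {
		log.Printf("cache hit: %s", val)
	}
}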

service/const.go (new file, 27 lines)

@@ -0,0 +1,27 @@
package service

const (
	HeaderKey              = "header"
	AppEntryKey            = "entry"
	UIDKey                 = "uid"
	EndpointKey            = "endpoint"
	UserInfoKey            = "userinfo"
	TokenKey               = "token"
	AuthorizationKey       = "authorization"
	SetAuthorizationKey    = "x-authorization-token"
	TraceIDKey             = "x-vaala-trace-id"
	ClientRequestIDKey     = "x-client-request-id"
	SessionKey             = "x-vaala-session"
	RegistrationSessionKey = "x-vaala-registration-session"
	LoginSessionKey        = "x-vaala-login-session"
)

const (
	RespSuccess = "success"
	ValueNone   = "none"
)

const (
	LangEN = "en"
	LangZH = "zh"
)

service/helper.go (new file, 129 lines)

@@ -0,0 +1,129 @@
package service

import (
	"ai-search/service/logger"
	"context"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/sashabaranov/go-openai"
)

// getOpenAIConfig returns the token and endpoint for the requested mode;
// "chat" selects the dedicated chat credentials.
func getOpenAIConfig(c context.Context, mode string) (token, endpoint string) {
	token = GetSettings().OpenAIAPIKey
	endpoint = GetSettings().OpenAIEndpint
	if mode == "chat" {
		token = GetSettings().OpenAIChatAPIKey
		endpoint = GetSettings().OpenAIChatEndpoint
	}
	return
}

// Queue is a fixed-length sliding window of durations, used to estimate the
// average inter-token delay of the upstream stream.
type Queue struct {
	data   []time.Duration
	length int
}

func NewQueue(length int, defaultValue int) *Queue {
	data := make([]time.Duration, 0, length)
	for i := 0; i < length; i++ {
		data = append(data, time.Duration(defaultValue)*time.Millisecond)
	}
	return &Queue{
		data:   data,
		length: length,
	}
}

func (q *Queue) Add(value time.Duration) {
	if len(q.data) >= q.length {
		q.data = q.data[1:]
	}
	q.data = append(q.data, value)
}

// Avg returns the average of the non-zero samples, scaled by an optional
// multiplier k and capped at 20ms so output never stalls visibly.
func (q *Queue) Avg(k ...int) time.Duration {
	param := 1
	if len(k) > 0 {
		param = k[0]
	}
	total := time.Duration(0)
	count := 0
	for _, value := range q.data {
		if value != 0 {
			total += value
			count++
		}
	}
	if count == 0 {
		return time.Duration(0)
	}
	ans := total * time.Duration(param) / time.Duration(count)
	if ans > time.Duration(20)*time.Millisecond {
		return time.Duration(20) * time.Millisecond
	}
	return ans
}

// streamResp relays an OpenAI chat completion stream to the client rune by
// rune, pacing writes with the observed upstream delay, and returns the full
// text that was sent.
func streamResp(c *gin.Context, resp *openai.ChatCompletionStream) string {
	result := ""
	if resp == nil {
		logger.Logger(c).Error("stream resp is nil")
		return result
	}
	_, ok := c.Writer.(http.Flusher)
	if !ok {
		logger.Logger(c).Panic("server does not support flushing")
	}
	defer func() {
		c.Writer.Flush()
	}()
	queue := NewQueue(GetSettings().OpenAIChatQueueLen, GetSettings().OpenAIChatNetworkDelay)
	ch := make(chan rune, 1024)
	go func(c *gin.Context, msgChan chan rune) {
		lastTime := time.Now()
		for {
			line, err := resp.Recv()
			if err != nil {
				// io.EOF also lands here and ends the stream.
				close(msgChan)
				logger.Logger(c).WithError(err).Error("read openai completion line error")
				return
			}
			if len(line.Choices[0].Delta.Content) == 0 {
				continue
			}
			nowTime := time.Now()
			// strings.Count with an empty substring returns rune count + 1,
			// which conveniently never divides by zero below.
			division := strings.Count(line.Choices[0].Delta.Content, "")
			for _, v := range line.Choices[0].Delta.Content {
				msgChan <- v
			}
			during := (nowTime.Sub(lastTime) + (time.Duration(GetSettings().OpenAIChatNetworkDelay) *
				time.Millisecond)) / time.Duration(division)
			queue.Add(during)
			lastTime = nowTime
		}
	}(c, ch)
	for char := range ch {
		str := string(char)
		_, err := c.Writer.WriteString(str)
		result += str
		if err != nil {
			logger.Logger(c).WithError(err).Error("write string to client error")
			return ""
		}
		c.Writer.Flush()
		// English words average about six letters while a UTF-8 CJK rune is
		// three bytes, so scaling the delay by len(str)*2 tries to pace one
		// English word roughly like one Chinese character.
		time.Sleep(queue.Avg(len(str) * 2))
	}
	logger.Logger(c).Info("finish stream text to client")
	return result
}
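
Queue amounts to a sliding-window moving average with a 20ms ceiling. A hypothetical test-style snippet (not part of the commit) showing the arithmetic:

// Hypothetical snippet, e.g. in a helper_test.go within package service.
package service

import (
	"fmt"
	"testing"
	"time"
)

func TestQueueAvgExample(t *testing.T) {
	q := NewQueue(10, 5)         // window of 10 samples, seeded at 5ms each
	q.Add(8 * time.Millisecond)  // evicts one 5ms seed
	q.Add(12 * time.Millisecond) // evicts another
	fmt.Println(q.Avg())  // (8*5ms + 8ms + 12ms) / 10 = 6ms
	fmt.Println(q.Avg(6)) // 6ms * 6 = 36ms, capped to 20ms
}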

service/logger/logger.go (new file, 11 lines)

@@ -0,0 +1,11 @@
package logger

import (
	"context"

	"github.com/sirupsen/logrus"
)

// Logger returns a logrus entry bound to the given request context.
func Logger(c context.Context) *logrus.Entry {
	return logrus.WithContext(c)
}

service/response.go (new file, 51 lines)

@@ -0,0 +1,51 @@
package service

import (
	"net/http"

	"github.com/gin-gonic/gin"
)

type CommonResp interface {
	gin.H
}

func OKResp[T CommonResp](c *gin.Context, origin *T) {
	if c.ContentType() == "application/x-protobuf" {
		c.ProtoBuf(http.StatusOK, origin)
	} else {
		c.JSON(http.StatusOK, OK(RespSuccess).WithBody(origin))
	}
}

func OKRespWithJsonMarshal[T CommonResp](c *gin.Context, origin *T) {
	c.JSON(http.StatusOK, OK(RespSuccess).WithBody(origin))
}

// ErrResp writes an error response. When errCode is given it is used as the
// HTTP status; otherwise the error is wrapped in a 200 response whose body
// carries code 500.
func ErrResp[T CommonResp](c *gin.Context, origin *T, err string, errCode ...int) {
	if c.ContentType() == "application/x-protobuf" {
		c.ProtoBuf(http.StatusInternalServerError, origin)
	} else {
		if len(errCode) > 0 {
			c.JSON(errCode[0], Err(err).WithBody(origin))
			return
		}
		c.JSON(http.StatusOK, Err(err).WithBody(origin))
	}
}

func ErrUnAuthorized(c *gin.Context, err string) {
	if c.ContentType() == "application/x-protobuf" {
		c.ProtoBuf(http.StatusUnauthorized, nil)
	} else {
		c.JSON(http.StatusUnauthorized, Err(err).WithBody(nil))
	}
}

func ErrNotFound(c *gin.Context, err string) {
	if c.ContentType() == "application/x-protobuf" {
		c.ProtoBuf(http.StatusNotFound, nil)
	} else {
		c.JSON(http.StatusNotFound, Err(err).WithBody(nil))
	}
}
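
A hypothetical handler showing the intended call pattern (pingHandler and brokenHandler are illustrative, not part of the commit):

// Hypothetical usage of the response helpers.
package service

import (
	"net/http"

	"github.com/gin-gonic/gin"
)

func pingHandler(c *gin.Context) {
	body := gin.H{"pong": true}
	// JSON clients get {"code":200,"msg":"success","body":{"pong":true}}.
	OKResp(c, &body)
}

func brokenHandler(c *gin.Context) {
	// Explicit status: responds 503 with a code-500 body.
	ErrResp[gin.H](c, nil, "search backend unavailable", http.StatusServiceUnavailable)
}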

service/result.go (new file, 55 lines)

@@ -0,0 +1,55 @@
package service

import (
	"github.com/gin-gonic/gin"
)

type Result struct {
	Code int         `json:"code,omitempty"`
	Msg  string      `json:"msg,omitempty"`
	Data gin.H       `json:"data,omitempty"`
	Body interface{} `json:"body,omitempty"`
}

func (r *Result) WithMsg(message string) *Result {
	r.Msg = message
	return r
}

func (r *Result) WithData(data gin.H) *Result {
	r.Data = data
	return r
}

func (r *Result) WithKeyValue(key string, value interface{}) *Result {
	if r.Data == nil {
		r.Data = gin.H{}
	}
	r.Data[key] = value
	return r
}

func (r *Result) WithBody(body interface{}) *Result {
	r.Body = body
	return r
}

func newResult(code int, msg string) *Result {
	return &Result{
		Code: code,
		Msg:  msg,
		Data: nil,
	}
}

func OK(msg string) *Result {
	return newResult(200, msg)
}

func Err(msg string) *Result {
	return newResult(500, msg)
}

func UnAuth(msg string) *Result {
	return newResult(401, msg)
}
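
Result is a fluent builder; a hypothetical chain (not part of the commit) and the JSON it yields:

// Hypothetical illustration of the fluent Result builder.
func resultExample() *Result {
	return OK(RespSuccess).
		WithKeyValue("uid", "u-123").
		WithBody(map[string]string{"lang": LangZH})
	// Marshals to:
	// {"code":200,"msg":"success","data":{"uid":"u-123"},"body":{"lang":"zh"}}
}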

service/rpc.go (new file, 51 lines)

@@ -0,0 +1,51 @@
package service

import (
	"ai-search/service/logger"
	"context"
	"fmt"

	"github.com/imroc/req/v3"
)

type SearchClient interface {
	Search(c context.Context, query string, nums int) (SearchResp, error)
}

type searchClient struct {
	URL string
}

type SearchResp struct {
	Results []SearchResult `json:"results"`
}

type SearchResult struct {
	Body  string `json:"body"`
	Href  string `json:"href"`
	Title string `json:"title"`
}

func NewSearchClient() SearchClient {
	return &searchClient{
		URL: GetSettings().RPCEndpoints.SearchURL,
	}
}

// Search POSTs the query as a form to the external search endpoint and
// decodes the JSON result list.
func (s *searchClient) Search(c context.Context, query string, nums int) (SearchResp, error) {
	resp := SearchResp{}
	_, err := req.C().R().
		SetFormData(map[string]string{
			"q":           query,
			"max_results": fmt.Sprintf("%d", nums),
		}).
		SetContentType("application/x-www-form-urlencoded").
		SetSuccessResult(&resp).
		Post(s.URL)
	if err != nil {
		logger.Logger(c).Error(err)
	}
	return resp, err
}
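
The only consumer in this commit is SearchHandler. A hypothetical standalone call (query value is a placeholder), plus the wire contract the external endpoint configured via RPC_SEARCH_URL is expected to honor, as read off the form fields and JSON tags above:

// Hypothetical illustration of the search RPC call and its wire contract.
func searchExample(ctx context.Context) {
	cli := NewSearchClient()
	resp, err := cli.Search(ctx, "golang sse streaming", 8)
	if err != nil {
		return
	}
	for _, r := range resp.Results {
		fmt.Println(r.Title, r.Href)
	}
	// The endpoint receives a form POST: q=<query>&max_results=<n>
	// and must answer JSON: {"results":[{"title":...,"href":...,"body":...}]}
}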

service/search_handler.go (new file, 214 lines)

@@ -0,0 +1,214 @@
package service

import (
	"ai-search/service/logger"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"
	"github.com/sashabaranov/go-openai"
)

func SearchHandler(c *gin.Context) {
	searchReq := &SearchReq{}
	if err := c.Copy().ShouldBindJSON(searchReq); err != nil {
		ErrResp[gin.H](c, nil, "error", http.StatusBadRequest)
		return
	}
	cachedResult, err := getCacheResp(c, searchReq.SearchUUID)
	if err == nil {
		logger.Logger(c).Infof("cache key hit [%s], query: [%s]", searchReq.SearchUUID, searchReq.Query)
		c.String(http.StatusOK, cachedResult)
		return
	}
	if searchReq.Query == "" && searchReq.SearchUUID == "" {
		ErrResp[gin.H](c, nil, "param is invalid", http.StatusBadRequest)
		return
	}
	if searchReq.Query == "" && searchReq.SearchUUID != "" {
		// A UUID without a query means the cached answer has expired.
		ErrResp[gin.H](c, nil, "content is gone", http.StatusGone)
		return
	}
	c.Writer.Header().Set("Content-Type", "text/event-stream")
	c.Writer.Header().Set("Cache-Control", "no-cache")
	c.Writer.Header().Set("Connection", "keep-alive")
	c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
	cli := NewSearchClient()
	searchResp, err := cli.Search(c, searchReq.Query, GetSettings().RAGSearchCount)
	if err != nil {
		logger.Logger(c).WithError(err).Errorf("client.Search error")
		return
	}
	ss := &Sources{}
	ss.FromSearchResp(&searchResp, searchReq.Query, searchReq.SearchUUID)
	originReq := &openai.ChatCompletionRequest{
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleSystem,
				Content: fmt.Sprintf(RagPrompt(), getSearchContext(ss)),
			},
			{
				Role:    openai.ChatMessageRoleUser,
				Content: searchReq.Query,
			},
		},
		Stream: true,
	}
	apiKey, endpoint := getOpenAIConfig(c, "chat")
	conf := openai.DefaultConfig(apiKey)
	conf.BaseURL = endpoint
	client := openai.NewClientWithConfig(conf)
	request := openai.ChatCompletionRequest{
		Model:       openai.GPT3Dot5Turbo,
		Messages:    originReq.Messages,
		Temperature: GetSettings().RAGParams.Temperature,
		MaxTokens:   GetSettings().RAGParams.MaxTokens,
		Stream:      true,
	}
	resp, err := client.CreateChatCompletionStream(
		context.Background(),
		request,
	)
	if err != nil {
		// streamResp tolerates a nil stream, so keep going and at least
		// deliver the sources and related questions.
		logger.Logger(c).WithError(err).Errorf("client.CreateChatCompletionStream error")
	}
	relatedStrChan := make(chan string)
	defer close(relatedStrChan)
	go func() {
		relatedStrChan <- getRelatedQuestionsResp(c, searchReq.Query, ss)
	}()
	finalResult := streamSearchItemResp(c, []string{
		ss.ToString(),
		"\n\n__LLM_RESPONSE__\n\n",
	})
	finalResult = finalResult + streamResp(c, resp)
	finalResult = finalResult + streamSearchItemResp(c, []string{
		"\n\n__RELATED_QUESTIONS__\n\n",
		// `[{"question": "What is the formal way to say hello in Chinese?"}, {"question": "How do you say 'How are you' in Chinese?"}]`,
		<-relatedStrChan,
	})
	GetSearchCache().Set([]byte(searchReq.SearchUUID), newCachedResult(searchReq.SearchUUID, searchReq.Query, finalResult).ToBytes(), GetSettings().RAGSearchCacheTime)
	logger.Logger(c).Infof("cache key miss [%s], query: [%s], set result to cache", searchReq.SearchUUID, searchReq.Query)
}

func getCacheResp(c *gin.Context, searchUUID string) (string, error) {
	ans, err := GetSearchCache().Get([]byte(searchUUID))
	if err != nil {
		return "", err
	}
	if ans != nil {
		cachedResult := &cachedResult{}
		cachedResult.FromBytes(ans)
		return cachedResult.Result, nil
	}
	return "", fmt.Errorf("cache not found")
}

func streamSearchItemResp(c *gin.Context, t []string) string {
	result := ""
	_, ok := c.Writer.(http.Flusher)
	if !ok {
		logger.Logger(c).Panic("server does not support flushing")
	}
	defer func() {
		c.Writer.Flush()
	}()
	for _, line := range t {
		_, err := c.Writer.WriteString(line)
		result += line
		if err != nil {
			logger.Logger(c).WithError(err).Error("write string to client error")
			return ""
		}
	}
	logger.Logger(c).Info("finish stream text to client")
	return result
}

func getRelatedQuestionsResp(c context.Context, query string, ss *Sources) string {
	apiKey, endpoint := getOpenAIConfig(c, "chat")
	conf := openai.DefaultConfig(apiKey)
	conf.BaseURL = endpoint
	client := openai.NewClientWithConfig(conf)
	request := openai.ChatCompletionRequest{
		Model: openai.GPT3Dot5Turbo,
		Messages: []openai.ChatCompletionMessage{
			{
				Role:    openai.ChatMessageRoleUser,
				Content: fmt.Sprintf(MoreQuestionsPrompt(), getSearchContext(ss)) + query,
			},
		},
		Temperature: GetSettings().RAGParams.MoreQuestionsTemperature,
		MaxTokens:   GetSettings().RAGParams.MoreQuestionsMaxTokens,
	}
	resp, err := client.CreateChatCompletion(
		context.Background(),
		request,
	)
	if err != nil || len(resp.Choices) == 0 {
		// Without a completion there is nothing to parse; return an empty
		// question list instead of panicking on resp.Choices[0].
		logger.Logger(c).WithError(err).Errorf("client.CreateChatCompletion error")
		return "[]"
	}
	// The model may answer either as "1. foo\n2. bar" or as "- foo\n- bar";
	// detect which and split accordingly.
	mode := 1
	cs := strings.Split(resp.Choices[0].Message.Content, ". ")
	if len(cs) == 1 {
		cs = strings.Split(resp.Choices[0].Message.Content, "- ")
		mode = 2
	}
	rq := []string{}
	for i, line := range cs {
		if len(line) <= 2 {
			continue
		}
		if i != len(cs)-1 && mode == 1 {
			// In numbered mode each fragment ends with the next item's
			// number; drop that trailing digit.
			line = line[:len(line)-1]
		}
		rq = append(rq, line)
	}
	return parseRelatedQuestionsResp(rq)
}

func parseRelatedQuestionsResp(qs []string) string {
	q := []struct {
		Question string `json:"question"`
	}{}
	for _, line := range qs {
		if len(strings.Trim(line, " ")) <= 2 {
			continue
		}
		q = append(q, struct {
			Question string `json:"question"`
		}{line})
	}
	rawBytes, _ := json.Marshal(q)
	return string(rawBytes)
}

func getSearchContext(ss *Sources) string {
	ans := ""
	for i, ctx := range ss.Contexts {
		ans = ans + fmt.Sprintf("[[citation:%d]] ", i+1) + ctx.Snippet + "\n\n"
	}
	return ans
}
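
End to end, the handler writes a single text stream in three segments: the JSON-encoded Sources, the LLM answer, and a JSON array of related questions, separated by two sentinel markers. A hypothetical client-side split of a complete (or cached) response, assuming both markers are present:

// Hypothetical parse of the three-segment stream format.
func splitStreamExample(full string) (sources, answer, related string) {
	head, rest, _ := strings.Cut(full, "\n\n__LLM_RESPONSE__\n\n")
	answerPart, relatedPart, _ := strings.Cut(rest, "\n\n__RELATED_QUESTIONS__\n\n")
	return head, answerPart, relatedPart // Sources JSON, answer text, questions JSON
}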

service/settings.go (new file, 69 lines)

@@ -0,0 +1,69 @@
package service

import (
	"ai-search/utils"

	"github.com/ilyakaznacheev/cleanenv"
	"github.com/joho/godotenv"
)

type AppSettings struct {
	ListenAddr     string `env:"LISTEN_ADDR" env-default:":8080"`
	DBPath         string `env:"DB_PATH" env-default:"/litefs/db"`
	OpenAIEndpint  string `env:"OPENAI_ENDPOINT" env-required:"true"`
	OpenAIAPIKey   string `env:"OPENAI_API_KEY" env-required:"true"`
	// The Chinese default below instructs the model to continue the user's
	// existing text as a writing assistant, weight the trailing context more
	// heavily, strongly prefer Chinese, and stay within 200 characters while
	// forming complete sentences.
	OpenAIWriterPrompt     string       `env:"OPENAI_WRITER_PROMPT" env-default:"您是一名人工智能写作助手,可以根据先前文本的上下文继续现有文本。给予后面的字符比开始的字符更多的权重/优先级。而且绝对一定要尽量多使用中文!将您的回答限制在 200 个字符以内,但请确保构建完整的句子。"`
	HttpProxy              string       `env:"HTTP_PROXY"`
	IsDebug                bool         `env:"DEBUG" env-default:"false"`
	ChatMaxTokens          int          `env:"CHAT_MAX_TOKENS" env-default:"400"`
	OpenAIChatEndpoint     string       `env:"OPENAI_CHAT_ENDPOINT" env-required:"true"`
	OpenAIChatAPIKey       string       `env:"OPENAI_CHAT_API_KEY" env-required:"true"`
	OpenAIChatNetworkDelay int          `env:"OPENAI_CHAT_NETWORK_DELAY" env-default:"5"`
	OpenAIChatQueueLen     int          `env:"OPENAI_CHAT_QUEUE_LEN" env-default:"10"`
	CacheSize              int          `env:"CACHE_SIZE" env-default:"100"` // in MB
	RAGSearchCount         int          `env:"RAG_SEARCH_COUNT" env-default:"8"`
	RAGSearchCacheTime     int          `env:"RAG_SEARCH_CACHE_TIME" env-default:"1200"` // sec
	RPCEndpoints           RPCEndpoints `env-prefix:"RPC_"`
	Prompts                Prompts      `env-prefix:"PROMPT_"`
	RAGParams              RAGParams    `env-prefix:"RAG_"`
}

type RPCEndpoints struct {
	SearchURL string `env:"SEARCH_URL" env-required:"true"`
}

type Prompts struct {
	RAGPath           string `env:"RAG_PATH" env-default:""`
	MoreQuestionsPath string `env:"MORE_QUESTIONS_PATH" env-default:""`
}

type RAGParams struct {
	MaxTokens                int     `env:"MAX_TOKENS" env-default:"2048"`
	MoreQuestionsMaxTokens   int     `env:"MORE_QUESTIONS_MAX_TOKENS" env-default:"1024"`
	Temperature              float32 `env:"TEMPERATURE" env-default:"0.9"`
	MoreQuestionsTemperature float32 `env:"MORE_QUESTIONS_TEMPERATURE" env-default:"0.7"`
}

var (
	appSetting *AppSettings
)

func init() {
	godotenv.Load()
	conf := &AppSettings{}
	if err := cleanenv.ReadEnv(conf); err != nil {
		panic(err)
	}
	appSetting = conf
	InitCache()
	// Only override the built-in prompts when a path is actually configured.
	if len(appSetting.Prompts.MoreQuestionsPath) > 0 {
		moreQuestionsPrompt = utils.GetFileContent(appSetting.Prompts.MoreQuestionsPath)
	}
	if len(appSetting.Prompts.RAGPath) > 0 {
		ragPrompt = utils.GetFileContent(appSetting.Prompts.RAGPath)
	}
}

func GetSettings() *AppSettings {
	return appSetting
}
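
Configuration is entirely environment-driven (godotenv loads a local .env first, then cleanenv reads the tags above, with nested prefixes like RPC_ and RAG_). A hypothetical .env covering the required variables; values are placeholders, not defaults from the code:

# Hypothetical .env; values are placeholders.
LISTEN_ADDR=:8080
OPENAI_ENDPOINT=https://api.openai.com/v1
OPENAI_API_KEY=sk-placeholder
OPENAI_CHAT_ENDPOINT=https://api.openai.com/v1
OPENAI_CHAT_API_KEY=sk-placeholder
# external search backend (see rpc.go)
RPC_SEARCH_URL=http://127.0.0.1:8000/search
# cache size in MB
CACHE_SIZE=100
# optional prompt overrides; leave unset to use the built-ins in types.go
PROMPT_RAG_PATH=
PROMPT_MORE_QUESTIONS_PATH=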

service/static.go (new file, 51 lines)

@@ -0,0 +1,51 @@
package service

import (
	"embed"
	"io/fs"
	"net/http"
	"strings"

	"github.com/gin-contrib/static"
	"github.com/gin-gonic/gin"
)

type embedFileSystem struct {
	http.FileSystem
}

func (e embedFileSystem) Exists(prefix string, path string) bool {
	_, err := e.Open(path)
	return err == nil
}

func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
	fsys, err := fs.Sub(fsEmbed, targetPath)
	if err != nil {
		panic(err)
	}
	return embedFileSystem{
		FileSystem: http.FS(fsys),
	}
}

// HandleStaticFile serves the embedded frontend build from "out" and
// rewrites extension-less GET paths outside /v1/ to their .html files.
func HandleStaticFile(router *gin.Engine, f embed.FS) {
	root := EmbedFolder(f, "out")
	// One middleware instance serves files that exist; a second handles the
	// .html rewrites from NoRoute.
	router.Use(static.Serve("/", root))
	staticServer := static.Serve("/", root)
	router.NoRoute(func(c *gin.Context) {
		if c.Request.Method == http.MethodGet &&
			!strings.ContainsRune(c.Request.URL.Path, '.') &&
			!strings.HasPrefix(c.Request.URL.Path, "/v1/") {
			if strings.HasSuffix(c.Request.URL.Path, "/") {
				c.Request.URL.Path += "index.html"
				staticServer(c)
				return
			}
			if !strings.HasSuffix(c.Request.URL.Path, ".html") {
				c.Request.URL.Path += ".html"
				staticServer(c)
				return
			}
		}
	})
}
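
A hypothetical main.go wiring (not part of the shown diff); it assumes the frontend build is committed under out/ at the module root, matching the "out" passed to EmbedFolder:

// Hypothetical wiring of the embedded static frontend.
package main

import (
	"embed"

	"ai-search/service"
	"github.com/gin-gonic/gin"
)

//go:embed out
var frontendFS embed.FS

func main() {
	router := gin.Default()
	service.HandleStaticFile(router, frontendFS)
	router.Run(service.GetSettings().ListenAddr)
}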

service/types.go (new file, 126 lines)

@@ -0,0 +1,126 @@
package service

import (
	"encoding/json"
	"time"
)

var (
	ragPrompt = `You are a large language AI assistant built by VaalaCat. You are given a user question, and please write clean, concise and accurate answer to the question. You will be given a set of related contexts to the question, each starting with a reference number like [[citation:x]], where x is a number. Please use the context and cite the context at the end of each sentence if applicable.
Your answer must be correct, accurate and written by an expert using an unbiased and professional tone. Please limit to 1024 tokens. Do not give any information that is not related to the question, and do not repeat. Say "information is missing on" followed by the related topic, if the given context do not provide sufficient information.
Please cite the contexts with the reference numbers, in the format [citation:x]. If a sentence comes from multiple contexts, please list all applicable citations, like [citation:3][citation:5]. Other than code and specific names and citations, your answer must be written in the same language as the question.
Here are the set of contexts:
%s
Remember, use Chinese more, don't blindly repeat the contexts verbatim. And here is the user question:
`
	moreQuestionsPrompt = `You are a helpful assistant that helps the user to ask related questions, based on user's original question and the related contexts. Please identify worthwhile topics that can be follow-ups, and write questions no longer than 20 words each. Please make sure that specifics, like events, names, locations, are included in follow up questions so they can be asked standalone. For example, if the original question asks about "the Manhattan project", in the follow up question, do not just say "the project", but use the full name "the Manhattan project". Your related questions must be in the same language as the original question.
Here are the contexts of the question:
%s
Remember, use Chinese more, based on the original question and related contexts, suggest three such further questions. Do NOT repeat the original question. Each related question should be no longer than 20 words. Here is the original question:
`
)

func RagPrompt() string {
	return ragPrompt
}

func MoreQuestionsPrompt() string {
	return moreQuestionsPrompt
}

type SearchReq struct {
	Query      string `json:"query"`
	SearchUUID string `json:"search_uuid"`
}

type Sources struct {
	Query    string   `json:"query"`
	RID      string   `json:"rid"`
	Contexts []Source `json:"contexts"`
}

func (ss *Sources) FromSearchResp(resp *SearchResp, query, rid string) {
	ss.Query = query
	ss.RID = rid
	ctxs := make([]Source, 0)
	for _, ctx := range resp.Results {
		ctxs = append(ctxs, Source{
			ID:      ctx.Href,
			Name:    ctx.Title,
			URL:     ctx.Href,
			Snippet: ctx.Body,
		})
	}
	ss.Contexts = ctxs
}

func (ss *Sources) ToString() string {
	rawBytes, _ := json.Marshal(ss)
	return string(rawBytes)
}

type Source struct {
	ID                 string        `json:"id"`
	Name               string        `json:"name"`
	URL                string        `json:"url"`
	IsFamilyFriendly   bool          `json:"isFamilyFriendly"`
	DisplayURL         string        `json:"displayUrl"`
	Snippet            string        `json:"snippet"`
	DeepLinks          []DeepLink    `json:"deepLinks"`
	DateLastCrawled    time.Time     `json:"dateLastCrawled"`
	CachedPageURL      string        `json:"cachedPageUrl"`
	Language           string        `json:"language"`
	PrimaryImageOfPage *PrimaryImage `json:"primaryImageOfPage,omitempty"`
	IsNavigational     bool          `json:"isNavigational"`
}

type DeepLink struct {
	Snippet string `json:"snippet"`
	Name    string `json:"name"`
	URL     string `json:"url"`
}

type PrimaryImage struct {
	ThumbnailURL string `json:"thumbnailUrl"`
	Width        int    `json:"width"`
	Height       int    `json:"height"`
	ImageID      string `json:"imageId"`
}

func (s *Source) FromSearchResp(resp *SearchResult) {
	s.ID = resp.Href
	s.Name = resp.Title
	s.URL = resp.Href
	s.Snippet = resp.Body
}

type cachedResult struct {
	SearchUUID string `json:"search_uuid"`
	Query      string `json:"query"`
	Result     string `json:"result"`
}

func (cs *cachedResult) FromBytes(rawBytes []byte) {
	json.Unmarshal(rawBytes, cs)
}

func (cs *cachedResult) ToBytes() []byte {
	rawBytes, _ := json.Marshal(cs)
	return rawBytes
}

func newCachedResult(searchUUID, query, result string) *cachedResult {
	return &cachedResult{
		SearchUUID: searchUUID,
		Query:      query,
		Result:     result,
	}
}