first commit
This commit is contained in:
18
service/cache.go
Normal file
18
service/cache.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/coocood/freecache"
|
||||
)
|
||||
|
||||
var (
|
||||
searchCache *freecache.Cache
|
||||
)
|
||||
|
||||
func InitCache() {
|
||||
cacheSize := GetSettings().CacheSize * 1024 * 1024 // 100 MB
|
||||
searchCache = freecache.NewCache(cacheSize)
|
||||
}
|
||||
|
||||
func GetSearchCache() *freecache.Cache {
|
||||
return searchCache
|
||||
}
|
27
service/const.go
Normal file
27
service/const.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package service
|
||||
|
||||
// Key names shared across the service.
// NOTE(review): the call sites are not visible in this file; the grouping
// below is inferred from the values only.
const (
	// Plain lower-case key names.
	HeaderKey        = "header"
	AppEntryKey      = "entry"
	UIDKey           = "uid"
	EndpointKey      = "endpoint"
	UserInfoKey      = "userinfo"
	TokenKey         = "token"
	AuthorizationKey = "authorization"

	// "x-"-prefixed names (header-style).
	SetAuthorizationKey    = "x-authorization-token"
	TraceIDKey             = "x-vaala-trace-id"
	ClientRequestIDKey     = "x-client-request-id"
	SessionKey             = "x-vaala-session"
	RegistrationSessionKey = "x-vaala-registration-session"
	LoginSessionKey        = "x-vaala-login-session"
)
|
||||
|
||||
// Generic response values.
const (
	// RespSuccess is the message used by the OK response helpers.
	RespSuccess = "success"
	// ValueNone is the literal "none".
	ValueNone = "none"
)
|
||||
|
||||
// Language codes.
const (
	// LangEN is the code "en".
	LangEN = "en"
	// LangZH is the code "zh".
	LangZH = "zh"
)
|
129
service/helper.go
Normal file
129
service/helper.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package service
|
||||
|
||||
import (
	"ai-search/service/logger"
	"context"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/sashabaranov/go-openai"
)
|
||||
|
||||
func getOpenAIConfig(c context.Context, mode string) (token, endpoint string) {
|
||||
token = GetSettings().OpenAIAPIKey
|
||||
endpoint = GetSettings().OpenAIEndpint
|
||||
if mode == "chat" {
|
||||
token = GetSettings().OpenAIChatAPIKey
|
||||
endpoint = GetSettings().OpenAIChatEndpoint
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Queue is a fixed-capacity sliding window of duration samples. It is safe
// for concurrent use: one goroutine may Add samples while another reads Avg.
type Queue struct {
	mu     sync.Mutex
	data   []time.Duration
	length int
}

// NewQueue returns a Queue holding `length` samples, each pre-filled with
// defaultValue milliseconds. A negative length is treated as zero.
func NewQueue(length int, defaultValue int) *Queue {
	if length < 0 {
		length = 0
	}
	seed := time.Duration(defaultValue) * time.Millisecond
	data := make([]time.Duration, length)
	for i := range data {
		data[i] = seed
	}
	return &Queue{
		data:   data,
		length: length,
	}
}

// Add appends value to the window, evicting the oldest sample once the
// window is full. A queue with no capacity silently drops the sample
// instead of panicking.
func (q *Queue) Add(value time.Duration) {
	q.mu.Lock()
	defer q.mu.Unlock()

	if q.length <= 0 {
		return
	}
	if len(q.data) >= q.length {
		// Shift left in place so the backing array is reused.
		copy(q.data, q.data[1:])
		q.data[len(q.data)-1] = value
		return
	}
	q.data = append(q.data, value)
}

// Avg returns the mean of the non-zero samples multiplied by the optional
// factor k[0] (default 1). The result is capped at 20ms, and is 0 when the
// window holds no non-zero sample.
func (q *Queue) Avg(k ...int) time.Duration {
	const maxDelay = 20 * time.Millisecond

	factor := 1
	if len(k) > 0 {
		factor = k[0]
	}

	q.mu.Lock()
	var total time.Duration
	count := 0
	for _, sample := range q.data {
		if sample != 0 {
			total += sample
			count++
		}
	}
	q.mu.Unlock()

	if count == 0 {
		return 0
	}
	avg := total * time.Duration(factor) / time.Duration(count)
	if avg > maxDelay {
		return maxDelay
	}
	return avg
}
|
||||
|
||||
func streamResp(c *gin.Context, resp *openai.ChatCompletionStream) string {
|
||||
result := ""
|
||||
|
||||
if resp == nil {
|
||||
logger.Logger(c).Error("stream resp is nil")
|
||||
return result
|
||||
}
|
||||
|
||||
_, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
logger.Logger(c).Panic("server not support")
|
||||
}
|
||||
defer func() {
|
||||
c.Writer.Flush()
|
||||
}()
|
||||
|
||||
queue := NewQueue(GetSettings().OpenAIChatQueueLen, GetSettings().OpenAIChatNetworkDelay)
|
||||
ch := make(chan rune, 1024)
|
||||
|
||||
go func(c *gin.Context, msgChan chan rune) {
|
||||
lastTime := time.Now()
|
||||
for {
|
||||
line, err := resp.Recv()
|
||||
if err != nil {
|
||||
close(msgChan)
|
||||
logger.Logger(c).WithError(err).Error("read openai completion line error")
|
||||
return
|
||||
}
|
||||
if len(line.Choices[0].Delta.Content) == 0 {
|
||||
continue
|
||||
}
|
||||
nowTime := time.Now()
|
||||
|
||||
division := strings.Count(line.Choices[0].Delta.Content, "")
|
||||
for _, v := range line.Choices[0].Delta.Content {
|
||||
msgChan <- v
|
||||
}
|
||||
|
||||
during := (nowTime.Sub(lastTime) + (time.Duration(GetSettings().OpenAIChatNetworkDelay) *
|
||||
time.Millisecond)) / time.Duration(division)
|
||||
queue.Add(during)
|
||||
lastTime = nowTime
|
||||
}
|
||||
}(c, ch)
|
||||
|
||||
for char := range ch {
|
||||
str := string(char)
|
||||
_, err := c.Writer.WriteString(str)
|
||||
result += str
|
||||
if err != nil {
|
||||
logger.Logger(c).WithError(err).Error("write string to client error")
|
||||
return ""
|
||||
}
|
||||
c.Writer.Flush()
|
||||
time.Sleep(queue.Avg(len(str) * 2)) // 英文平均长度为6个字符,一个UTF8字符是3个长度,试图让一个单词等于一个汉字
|
||||
}
|
||||
logger.Logger(c).Info("finish stream text to client")
|
||||
return result
|
||||
}
|
11
service/logger/logger.go
Normal file
11
service/logger/logger.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func Logger(c context.Context) *logrus.Entry {
|
||||
return logrus.WithContext(c)
|
||||
}
|
51
service/response.go
Normal file
51
service/response.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type CommonResp interface {
|
||||
gin.H
|
||||
}
|
||||
|
||||
func OKResp[T CommonResp](c *gin.Context, origin *T) {
|
||||
if c.ContentType() == "application/x-protobuf" {
|
||||
c.ProtoBuf(http.StatusOK, origin)
|
||||
} else {
|
||||
c.JSON(http.StatusOK, OK(RespSuccess).WithBody(origin))
|
||||
}
|
||||
}
|
||||
|
||||
func OKRespWithJsonMarshal[T CommonResp](c *gin.Context, origin *T) {
|
||||
c.JSON(http.StatusOK, OK(RespSuccess).WithBody(origin))
|
||||
}
|
||||
|
||||
func ErrResp[T CommonResp](c *gin.Context, origin *T, err string, errCode ...int) {
|
||||
if c.ContentType() == "application/x-protobuf" {
|
||||
c.ProtoBuf(http.StatusInternalServerError, origin)
|
||||
} else {
|
||||
if len(errCode) > 0 {
|
||||
c.JSON(errCode[0], Err(err).WithBody(origin))
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, Err(err).WithBody(origin))
|
||||
}
|
||||
}
|
||||
|
||||
func ErrUnAuthorized(c *gin.Context, err string) {
|
||||
if c.ContentType() == "application/x-protobuf" {
|
||||
c.ProtoBuf(http.StatusUnauthorized, nil)
|
||||
} else {
|
||||
c.JSON(http.StatusUnauthorized, Err(err).WithBody(nil))
|
||||
}
|
||||
}
|
||||
|
||||
func ErrNotFound(c *gin.Context, err string) {
|
||||
if c.ContentType() == "application/x-protobuf" {
|
||||
c.ProtoBuf(http.StatusNotFound, nil)
|
||||
} else {
|
||||
c.JSON(http.StatusNotFound, Err(err).WithBody(nil))
|
||||
}
|
||||
}
|
55
service/result.go
Normal file
55
service/result.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Result struct {
|
||||
Code int `json:"code,omitempty"`
|
||||
Msg string `json:"msg,omitempty"`
|
||||
Data gin.H `json:"data,omitempty"`
|
||||
Body interface{} `json:"body,omitempty"`
|
||||
}
|
||||
|
||||
func (r *Result) WithMsg(message string) *Result {
|
||||
r.Msg = message
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Result) WithData(data gin.H) *Result {
|
||||
r.Data = data
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Result) WithKeyValue(key string, value interface{}) *Result {
|
||||
if r.Data == nil {
|
||||
r.Data = gin.H{}
|
||||
}
|
||||
r.Data[key] = value
|
||||
return r
|
||||
}
|
||||
|
||||
func (r *Result) WithBody(body interface{}) *Result {
|
||||
r.Body = body
|
||||
return r
|
||||
}
|
||||
|
||||
func newResult(code int, msg string) *Result {
|
||||
return &Result{
|
||||
Code: code,
|
||||
Msg: msg,
|
||||
Data: nil,
|
||||
}
|
||||
}
|
||||
|
||||
func OK(msg string) *Result {
|
||||
return newResult(200, msg)
|
||||
}
|
||||
|
||||
func Err(msg string) *Result {
|
||||
return newResult(500, msg)
|
||||
}
|
||||
|
||||
func UnAuth(msg string) *Result {
|
||||
return newResult(401, msg)
|
||||
}
|
51
service/rpc.go
Normal file
51
service/rpc.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"ai-search/service/logger"
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/imroc/req/v3"
|
||||
)
|
||||
|
||||
type SearchClient interface {
|
||||
Search(c context.Context, query string, nums int) (SearchResp, error)
|
||||
}
|
||||
|
||||
// searchClient is the HTTP-backed SearchClient implementation.
type searchClient struct {
	// URL is the address search requests are POSTed to.
	URL string
}
|
||||
|
||||
type SearchResp struct {
|
||||
Results []SearchResult `json:"results"`
|
||||
}
|
||||
|
||||
// SearchResult is a single hit in a SearchResp.
type SearchResult struct {
	Body  string `json:"body"`  // text of the hit
	Href  string `json:"href"`  // link of the hit
	Title string `json:"title"` // title of the hit
}
|
||||
|
||||
func NewSearchClient() SearchClient {
|
||||
return &searchClient{
|
||||
URL: GetSettings().RPCEndpoints.SearchURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *searchClient) Search(c context.Context, query string, nums int) (SearchResp, error) {
|
||||
resp := SearchResp{}
|
||||
_, err := req.C().R().
|
||||
SetFormData(map[string]string{
|
||||
"q": query,
|
||||
"max_results": fmt.Sprintf("%d", nums),
|
||||
}).
|
||||
SetContentType("application/x-www-form-urlencoded").
|
||||
SetSuccessResult(&resp).
|
||||
Post(s.URL)
|
||||
|
||||
if err != nil {
|
||||
logger.Logger(c).Error(err)
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
214
service/search_handler.go
Normal file
214
service/search_handler.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"ai-search/service/logger"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/sashabaranov/go-openai"
|
||||
)
|
||||
|
||||
func SearchHandler(c *gin.Context) {
|
||||
searchReq := &SearchReq{}
|
||||
if err := c.Copy().ShouldBindJSON(searchReq); err != nil {
|
||||
ErrResp[gin.H](c, nil, "error", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
cachedResult, err := getCacheResp(c, searchReq.SearchUUID)
|
||||
if err == nil {
|
||||
logger.Logger(c).Infof("cache key hit [%s], query: [%s]", searchReq.SearchUUID, searchReq.Query)
|
||||
c.String(http.StatusOK, cachedResult)
|
||||
return
|
||||
}
|
||||
|
||||
if searchReq.Query == "" && searchReq.SearchUUID == "" {
|
||||
ErrResp[gin.H](c, nil, "param is invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if searchReq.Query == "" && searchReq.SearchUUID != "" {
|
||||
ErrResp[gin.H](c, nil, "content is gone", http.StatusGone)
|
||||
return
|
||||
}
|
||||
|
||||
c.Writer.Header().Set("Content-Type", "text/event-stream")
|
||||
c.Writer.Header().Set("Cache-Control", "no-cache")
|
||||
c.Writer.Header().Set("Connection", "keep-alive")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
cli := NewSearchClient()
|
||||
searchResp, err := cli.Search(c, searchReq.Query, GetSettings().RAGSearchCount)
|
||||
if err != nil {
|
||||
logger.Logger(c).WithError(err).Errorf("client.Search error")
|
||||
return
|
||||
}
|
||||
ss := &Sources{}
|
||||
ss.FromSearchResp(&searchResp, searchReq.Query, searchReq.SearchUUID)
|
||||
|
||||
originReq := &openai.ChatCompletionRequest{
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleSystem,
|
||||
Content: fmt.Sprintf(RagPrompt(), getSearchContext(ss)),
|
||||
},
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: searchReq.Query,
|
||||
},
|
||||
},
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
apiKey, endpoint := getOpenAIConfig(c, "chat")
|
||||
conf := openai.DefaultConfig(apiKey)
|
||||
conf.BaseURL = endpoint
|
||||
client := openai.NewClientWithConfig(conf)
|
||||
request := openai.ChatCompletionRequest{
|
||||
Model: openai.GPT3Dot5Turbo,
|
||||
Messages: originReq.Messages,
|
||||
Temperature: GetSettings().RAGParams.Temperature,
|
||||
MaxTokens: GetSettings().RAGParams.MaxTokens,
|
||||
Stream: true,
|
||||
}
|
||||
|
||||
resp, err := client.CreateChatCompletionStream(
|
||||
context.Background(),
|
||||
request,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Logger(c).WithError(err).Errorf("client.CreateChatCompletionStream error")
|
||||
}
|
||||
|
||||
relatedStrChan := make(chan string)
|
||||
defer close(relatedStrChan)
|
||||
go func() {
|
||||
relatedStrChan <- getRelatedQuestionsResp(c, searchReq.Query, ss)
|
||||
}()
|
||||
|
||||
finalResult := streamSearchItemResp(c, []string{
|
||||
ss.ToString(),
|
||||
"\n\n__LLM_RESPONSE__\n\n",
|
||||
})
|
||||
finalResult = finalResult + streamResp(c, resp)
|
||||
finalResult = finalResult + streamSearchItemResp(c, []string{
|
||||
"\n\n__RELATED_QUESTIONS__\n\n",
|
||||
// `[{"question": "What is the formal way to say hello in Chinese?"}, {"question": "How do you say 'How are you' in Chinese?"}]`,
|
||||
<-relatedStrChan,
|
||||
})
|
||||
|
||||
GetSearchCache().Set([]byte(searchReq.SearchUUID), newCachedResult(searchReq.SearchUUID, searchReq.Query, finalResult).ToBytes(), GetSettings().RAGSearchCacheTime)
|
||||
logger.Logger(c).Infof("cache key miss [%s], query: [%s], set result to cache", searchReq.SearchUUID, searchReq.Query)
|
||||
}
|
||||
|
||||
func getCacheResp(c *gin.Context, searchUUID string) (string, error) {
|
||||
ans, err := GetSearchCache().Get([]byte(searchUUID))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if ans != nil {
|
||||
cachedResult := &cachedResult{}
|
||||
cachedResult.FromBytes(ans)
|
||||
return cachedResult.Result, nil
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("cache not found")
|
||||
}
|
||||
|
||||
func streamSearchItemResp(c *gin.Context, t []string) string {
|
||||
result := ""
|
||||
|
||||
_, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
logger.Logger(c).Panic("server not support")
|
||||
}
|
||||
defer func() {
|
||||
c.Writer.Flush()
|
||||
}()
|
||||
|
||||
for _, line := range t {
|
||||
_, err := c.Writer.WriteString(line)
|
||||
result += line
|
||||
if err != nil {
|
||||
logger.Logger(c).WithError(err).Error("write string to client error")
|
||||
return ""
|
||||
}
|
||||
}
|
||||
logger.Logger(c).Info("finish stream text to client")
|
||||
return result
|
||||
}
|
||||
|
||||
func getRelatedQuestionsResp(c context.Context, query string, ss *Sources) string {
|
||||
apiKey, endpoint := getOpenAIConfig(c, "chat")
|
||||
conf := openai.DefaultConfig(apiKey)
|
||||
conf.BaseURL = endpoint
|
||||
client := openai.NewClientWithConfig(conf)
|
||||
request := openai.ChatCompletionRequest{
|
||||
Model: openai.GPT3Dot5Turbo,
|
||||
Messages: []openai.ChatCompletionMessage{
|
||||
{
|
||||
Role: openai.ChatMessageRoleUser,
|
||||
Content: fmt.Sprintf(MoreQuestionsPrompt(), getSearchContext(ss)) + query,
|
||||
},
|
||||
},
|
||||
Temperature: GetSettings().RAGParams.MoreQuestionsTemperature,
|
||||
MaxTokens: GetSettings().RAGParams.MoreQuestionsMaxTokens,
|
||||
}
|
||||
|
||||
resp, err := client.CreateChatCompletion(
|
||||
context.Background(),
|
||||
request,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Logger(c).WithError(err).Errorf("client.CreateChatCompletion error")
|
||||
}
|
||||
|
||||
mode := 1
|
||||
|
||||
cs := strings.Split(resp.Choices[0].Message.Content, ". ")
|
||||
if len(cs) == 1 {
|
||||
cs = strings.Split(resp.Choices[0].Message.Content, "- ")
|
||||
mode = 2
|
||||
}
|
||||
rq := []string{}
|
||||
for i, line := range cs {
|
||||
if len(line) <= 2 {
|
||||
continue
|
||||
}
|
||||
if i != len(cs)-1 && mode == 1 {
|
||||
line = line[:len(line)-1]
|
||||
}
|
||||
rq = append(rq, line)
|
||||
}
|
||||
|
||||
return parseRelatedQuestionsResp(rq)
|
||||
}
|
||||
|
||||
// parseRelatedQuestionsResp encodes qs as a JSON array of
// {"question": "..."} objects. Entries that are at most two bytes long once
// surrounding spaces are removed are dropped; kept entries are emitted
// untrimmed. An empty input yields "[]".
func parseRelatedQuestionsResp(qs []string) string {
	type relatedQuestion struct {
		Question string `json:"question"`
	}

	items := make([]relatedQuestion, 0, len(qs))
	for _, candidate := range qs {
		if len(strings.Trim(candidate, " ")) > 2 {
			items = append(items, relatedQuestion{Question: candidate})
		}
	}

	encoded, _ := json.Marshal(items)
	return string(encoded)
}
|
||||
|
||||
func getSearchContext(ss *Sources) string {
|
||||
ans := ""
|
||||
for i, ctx := range ss.Contexts {
|
||||
ans = ans + fmt.Sprintf("[[citation:%d]] ", i+1) + ctx.Snippet + "\n\n"
|
||||
}
|
||||
return ans
|
||||
}
|
69
service/settings.go
Normal file
69
service/settings.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"ai-search/utils"
|
||||
|
||||
"github.com/ilyakaznacheev/cleanenv"
|
||||
"github.com/joho/godotenv"
|
||||
)
|
||||
|
||||
type AppSettings struct {
|
||||
ListenAddr string `env:"LISTEN_ADDR" env-default:":8080"`
|
||||
DBPath string `env:"DB_PATH" env-default:"/litefs/db"`
|
||||
OpenAIEndpint string `env:"OPENAI_ENDPOINT" env-required:"true"`
|
||||
OpenAIAPIKey string `env:"OPENAI_API_KEY" env-required:"true"`
|
||||
OpenAIWriterPrompt string `env:"OPENAI_WRITER_PROMPT" env-default:"您是一名人工智能写作助手,可以根据先前文本的上下文继续现有文本。给予后面的字符比开始的字符更多的权重/优先级。而且绝对一定要尽量多使用中文!将您的回答限制在 200 个字符以内,但请确保构建完整的句子。"`
|
||||
HttpProxy string `env:"HTTP_PROXY"`
|
||||
IsDebug bool `env:"DEBUG" env-default:"false"`
|
||||
ChatMaxTokens int `env:"CHAT_MAX_TOKENS" env-default:"400"`
|
||||
OpenAIChatEndpoint string `env:"OPENAI_CHAT_ENDPOINT" env-required:"true"`
|
||||
OpenAIChatAPIKey string `env:"OPENAI_CHAT_API_KEY" env-required:"true"`
|
||||
OpenAIChatNetworkDelay int `env:"OPENAI_CHAT_NETWORK_DELAY" env-default:"5"`
|
||||
OpenAIChatQueueLen int `env:"OPENAI_CHAT_QUEUE_LEN" env-default:"10"`
|
||||
CacheSize int `env:"CACHE_SIZE" env-default:"100"` // in MB
|
||||
RAGSearchCount int `env:"RAG_SEARCH_COUNT" env-default:"8"`
|
||||
RAGSearchCacheTime int `env:"RAG_SEARCH_CACHE_TIME" env-default:"1200"` // sec
|
||||
RPCEndpoints RPCEndpoints `env-prefix:"RPC_"`
|
||||
Prompts Prompts `env-prefix:"PROMPT_"`
|
||||
RAGParams RAGParams `env-prefix:"RAG_"`
|
||||
}
|
||||
|
||||
// RPCEndpoints holds the addresses of remote services.
type RPCEndpoints struct {
	// SearchURL is required and read from SEARCH_URL.
	SearchURL string `env:"SEARCH_URL" env-required:"true"`
}
|
||||
|
||||
// Prompts holds optional file paths for prompt templates. Both default to
// the empty string.
type Prompts struct {
	// RAGPath is read from RAG_PATH.
	RAGPath string `env:"RAG_PATH" env-default:""`
	// MoreQuestionsPath is read from MORE_QUESTIONS_PATH.
	MoreQuestionsPath string `env:"MORE_QUESTIONS_PATH" env-default:""`
}
|
||||
|
||||
// RAGParams holds the sampling parameters for the answer completion and for
// the related-questions completion.
type RAGParams struct {
	// Answer completion.
	MaxTokens   int     `env:"MAX_TOKENS" env-default:"2048"`
	Temperature float32 `env:"TEMPERATURE" env-default:"0.9"`

	// Related-questions completion.
	MoreQuestionsMaxTokens   int     `env:"MORE_QUESTIONS_MAX_TOKENS" env-default:"1024"`
	MoreQuestionsTemperature float32 `env:"MORE_QUESTIONS_TEMPERATURE" env-default:"0.7"`
}
|
||||
|
||||
var (
|
||||
appSetting *AppSettings
|
||||
)
|
||||
|
||||
func init() {
|
||||
godotenv.Load()
|
||||
conf := &AppSettings{}
|
||||
if err := cleanenv.ReadEnv(conf); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
appSetting = conf
|
||||
InitCache()
|
||||
if len(appSetting.Prompts.MoreQuestionsPath) >= 0 {
|
||||
moreQuestionsPrompt = utils.GetFileContent(appSetting.Prompts.MoreQuestionsPath)
|
||||
}
|
||||
if len(appSetting.Prompts.RAGPath) >= 0 {
|
||||
ragPrompt = utils.GetFileContent(appSetting.Prompts.RAGPath)
|
||||
}
|
||||
}
|
||||
|
||||
func GetSettings() *AppSettings {
|
||||
return appSetting
|
||||
}
|
51
service/static.go
Normal file
51
service/static.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-contrib/static"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// embedFileSystem wraps an http.FileSystem and adds the Exists method.
type embedFileSystem struct {
	http.FileSystem
}

// Exists reports whether path can be opened in the wrapped file system.
// The prefix argument is not used. The probe handle is closed before
// returning so no open file is left behind.
func (e embedFileSystem) Exists(prefix string, path string) bool {
	f, err := e.Open(path)
	if err != nil {
		return false
	}
	// Only existence matters here, so a close error is irrelevant.
	_ = f.Close()
	return true
}
|
||||
func EmbedFolder(fsEmbed embed.FS, targetPath string) static.ServeFileSystem {
|
||||
fsys, err := fs.Sub(fsEmbed, targetPath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return embedFileSystem{
|
||||
FileSystem: http.FS(fsys),
|
||||
}
|
||||
}
|
||||
|
||||
func HandleStaticFile(router *gin.Engine, f embed.FS) {
|
||||
root := EmbedFolder(f, "out")
|
||||
router.Use(static.Serve("/", root))
|
||||
staticServer := static.Serve("/", root)
|
||||
router.NoRoute(func(c *gin.Context) {
|
||||
if c.Request.Method == http.MethodGet &&
|
||||
!strings.ContainsRune(c.Request.URL.Path, '.') &&
|
||||
!strings.HasPrefix(c.Request.URL.Path, "/v1/") {
|
||||
if strings.HasSuffix(c.Request.URL.Path, "/") {
|
||||
c.Request.URL.Path += "index.html"
|
||||
staticServer(c)
|
||||
return
|
||||
}
|
||||
if !strings.HasSuffix(c.Request.URL.Path, ".html") {
|
||||
c.Request.URL.Path += ".html"
|
||||
staticServer(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
126
service/types.go
Normal file
126
service/types.go
Normal file
@@ -0,0 +1,126 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Built-in prompt templates. Each contains a single %s verb that is filled
// with the rendered search contexts via fmt.Sprintf. Both are exposed
// through the RagPrompt and MoreQuestionsPrompt accessors.
var (
|
||||
ragPrompt = `You are a large language AI assistant built by VaalaCat. You are given a user question, and please write clean, concise and accurate answer to the question. You will be given a set of related contexts to the question, each starting with a reference number like [[citation:x]], where x is a number. Please use the context and cite the context at the end of each sentence if applicable.
|
||||
|
||||
Your answer must be correct, accurate and written by an expert using an unbiased and professional tone. Please limit to 1024 tokens. Do not give any information that is not related to the question, and do not repeat. Say "information is missing on" followed by the related topic, if the given context do not provide sufficient information.
|
||||
|
||||
Please cite the contexts with the reference numbers, in the format [citation:x]. If a sentence comes from multiple contexts, please list all applicable citations, like [citation:3][citation:5]. Other than code and specific names and citations, your answer must be written in the same language as the question.
|
||||
|
||||
Here are the set of contexts:
|
||||
|
||||
%s
|
||||
|
||||
Remember, use Chinese more, don't blindly repeat the contexts verbatim. And here is the user question:
|
||||
`
|
||||
// moreQuestionsPrompt is the template for the related-questions request;
// the caller appends the original question after the formatted template.
moreQuestionsPrompt = `You are a helpful assistant that helps the user to ask related questions, based on user's original question and the related contexts. Please identify worthwhile topics that can be follow-ups, and write questions no longer than 20 words each. Please make sure that specifics, like events, names, locations, are included in follow up questions so they can be asked standalone. For example, if the original question asks about "the Manhattan project", in the follow up question, do not just say "the project", but use the full name "the Manhattan project". Your related questions must be in the same language as the original question.
|
||||
|
||||
Here are the contexts of the question:
|
||||
|
||||
%s
|
||||
|
||||
Remember, use Chinese more, based on the original question and related contexts, suggest three such further questions. Do NOT repeat the original question. Each related question should be no longer than 20 words. Here is the original question:
|
||||
`
|
||||
)
|
||||
|
||||
func RagPrompt() string {
|
||||
return ragPrompt
|
||||
}
|
||||
|
||||
func MoreQuestionsPrompt() string {
|
||||
return moreQuestionsPrompt
|
||||
}
|
||||
|
||||
// SearchReq is the JSON body of a search request.
type SearchReq struct {
	// Query is the user's question ("query").
	Query string `json:"query"`
	// SearchUUID identifies the search and is used as the cache key
	// ("search_uuid").
	SearchUUID string `json:"search_uuid"`
}
|
||||
|
||||
type Sources struct {
|
||||
Query string `json:"query"`
|
||||
RID string `json:"rid"`
|
||||
Contexts []Source `json:"contexts"`
|
||||
}
|
||||
|
||||
func (ss *Sources) FromSearchResp(resp *SearchResp, query, rid string) {
|
||||
ss.Query = query
|
||||
ss.RID = rid
|
||||
ctxs := make([]Source, 0)
|
||||
for _, ctx := range resp.Results {
|
||||
ctxs = append(ctxs, Source{
|
||||
ID: ctx.Href,
|
||||
Name: ctx.Title,
|
||||
URL: ctx.Href,
|
||||
Snippet: ctx.Body,
|
||||
})
|
||||
}
|
||||
ss.Contexts = ctxs
|
||||
}
|
||||
|
||||
func (ss *Sources) ToString() string {
|
||||
rawBytes, _ := json.Marshal(ss)
|
||||
return string(rawBytes)
|
||||
}
|
||||
|
||||
type Source struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
URL string `json:"url"`
|
||||
IsFamilyFriendly bool `json:"isFamilyFriendly"`
|
||||
DisplayURL string `json:"displayUrl"`
|
||||
Snippet string `json:"snippet"`
|
||||
DeepLinks []DeepLink `json:"deepLinks"`
|
||||
DateLastCrawled time.Time `json:"dateLastCrawled"`
|
||||
CachedPageURL string `json:"cachedPageUrl"`
|
||||
Language string `json:"language"`
|
||||
PrimaryImageOfPage *PrimaryImage `json:"primaryImageOfPage,omitempty"`
|
||||
IsNavigational bool `json:"isNavigational"`
|
||||
}
|
||||
|
||||
// DeepLink is an entry of Source.DeepLinks.
type DeepLink struct {
	Snippet string `json:"snippet"`
	Name    string `json:"name"`
	URL     string `json:"url"`
}
|
||||
|
||||
// PrimaryImage is the value of Source.PrimaryImageOfPage.
type PrimaryImage struct {
	ThumbnailURL string `json:"thumbnailUrl"`
	Width        int    `json:"width"`
	Height       int    `json:"height"`
	ImageID      string `json:"imageId"`
}
|
||||
|
||||
func (s *Source) FromSearchResp(resp *SearchResult) {
|
||||
s.ID = resp.Href
|
||||
s.Name = resp.Title
|
||||
s.URL = resp.Href
|
||||
s.Snippet = resp.Body
|
||||
}
|
||||
|
||||
// cachedResult is the JSON record stored in the search cache for one
// finished search.
type cachedResult struct {
	SearchUUID string `json:"search_uuid"`
	Query      string `json:"query"`
	Result     string `json:"result"`
}

// FromBytes decodes a JSON record into cs. A decoding error is ignored.
func (cs *cachedResult) FromBytes(rawBytes []byte) {
	_ = json.Unmarshal(rawBytes, cs)
}

// ToBytes encodes cs as JSON. A marshalling error is ignored.
func (cs *cachedResult) ToBytes() []byte {
	encoded, _ := json.Marshal(cs)
	return encoded
}

// newCachedResult builds the cache record for a finished search.
func newCachedResult(searchUUID, query, result string) *cachedResult {
	record := new(cachedResult)
	record.SearchUUID = searchUUID
	record.Query = query
	record.Result = result
	return record
}
|
Reference in New Issue
Block a user