From 3af6192633f052a0a96145a0f2079aa9784e19d3 Mon Sep 17 00:00:00 2001 From: cookeem Date: Mon, 27 Mar 2023 12:00:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81=E4=B8=8A=E4=B8=8B=E6=96=87?= =?UTF-8?q?=E5=AF=B9=E8=AF=9D=E6=96=B9=E5=BC=8F=E8=BF=BD=E5=8A=A0=E9=97=AE?= =?UTF-8?q?=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chat/service.go | 31 +++++++++++++++++-------------- go.mod | 2 +- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/chat/service.go b/chat/service.go index 0b16c55..ed2f960 100644 --- a/chat/service.go +++ b/chat/service.go @@ -76,7 +76,7 @@ func (api *Api) wsPingMsg(conn *websocket.Conn, chClose, chIsCloseSet chan int) } } -func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex *sync.Mutex, requestMsg string) { +func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex *sync.Mutex, reqMsgs []openai.ChatCompletionMessage) { var err error var strResp string @@ -84,16 +84,12 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex * switch api.Config.Model { case openai.GPT3Dot5Turbo0301, openai.GPT3Dot5Turbo, openai.GPT4, openai.GPT40314, openai.GPT432K0314, openai.GPT432K: + prompt := reqMsgs[len(reqMsgs)-1].Content req := openai.ChatCompletionRequest{ - Model: api.Config.Model, - MaxTokens: api.Config.MaxLength, - Temperature: 1.0, - Messages: []openai.ChatCompletionMessage{ - { - Role: openai.ChatMessageRoleUser, - Content: requestMsg, - }, - }, + Model: api.Config.Model, + MaxTokens: api.Config.MaxLength, + Temperature: 1.0, + Messages: reqMsgs, Stream: true, TopP: 1, FrequencyPenalty: 0.1, @@ -151,7 +147,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex * if len(response.Choices) > 0 { var s string if i == 0 { - s = fmt.Sprintf(`%s# %s`, s, requestMsg) + s = fmt.Sprintf("%s# %s\n\n", s, prompt) } for _, choice := range response.Choices { s = s + choice.Delta.Content @@ -173,11 +169,12 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex * api.Logger.LogInfo(fmt.Sprintf("[RESPONSE] %s\n", strResp)) } case openai.GPT3TextDavinci003, openai.GPT3TextDavinci002, openai.GPT3TextCurie001, openai.GPT3TextBabbage001, openai.GPT3TextAda001, openai.GPT3TextDavinci001, openai.GPT3DavinciInstructBeta, openai.GPT3Davinci, openai.GPT3CurieInstructBeta, openai.GPT3Curie, openai.GPT3Ada, openai.GPT3Babbage: + prompt := reqMsgs[len(reqMsgs)-1].Content req := openai.CompletionRequest{ Model: api.Config.Model, MaxTokens: api.Config.MaxLength, Temperature: 0.6, - Prompt: requestMsg, + Prompt: prompt, Stream: true, //Stop: []string{"\n\n\n"}, TopP: 1, @@ -236,7 +233,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex * if len(response.Choices) > 0 { var s string if i == 0 { - s = fmt.Sprintf(`%s# %s`, s, requestMsg) + s = fmt.Sprintf("%s# %s\n\n", s, prompt) } for _, choice := range response.Choices { s = s + choice.Text @@ -393,6 +390,8 @@ func (api *Api) WsChat(c *gin.Context) { api.Logger.LogInfo(fmt.Sprintf("websocket connection open")) cli := openai.NewClient(api.Config.ApiKey) + reqMsgs := make([]openai.ChatCompletionMessage, 0) + var latestRequestTime time.Time for { if isClosed { @@ -466,7 +465,11 @@ func (api *Api) WsChat(c *gin.Context) { mutex.Lock() _ = conn.WriteJSON(chatMsg) mutex.Unlock() - go api.GetChatMessage(conn, cli, mutex, requestMsg) + reqMsgs = append(reqMsgs, openai.ChatCompletionMessage{ + Role: openai.ChatMessageRoleUser, + Content: requestMsg, + }) + go api.GetChatMessage(conn, cli, mutex, reqMsgs) } } } diff --git a/go.mod b/go.mod index 1119277..b9bbf26 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/gin-gonic/gin v1.8.2 github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 - github.com/sashabaranov/go-openai v1.5.4 + github.com/sashabaranov/go-openai v1.5.7 github.com/sirupsen/logrus v1.9.0 gopkg.in/yaml.v3 v3.0.1 )