支持上下文对话方式追加问题

This commit is contained in:
cookeem
2023-03-27 12:00:57 +08:00
parent 5b75b51059
commit 3af6192633
2 changed files with 18 additions and 15 deletions

View File

@@ -76,7 +76,7 @@ func (api *Api) wsPingMsg(conn *websocket.Conn, chClose, chIsCloseSet chan int)
}
}
func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex *sync.Mutex, requestMsg string) {
func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex *sync.Mutex, reqMsgs []openai.ChatCompletionMessage) {
var err error
var strResp string
@@ -84,16 +84,12 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex *
switch api.Config.Model {
case openai.GPT3Dot5Turbo0301, openai.GPT3Dot5Turbo, openai.GPT4, openai.GPT40314, openai.GPT432K0314, openai.GPT432K:
prompt := reqMsgs[len(reqMsgs)-1].Content
req := openai.ChatCompletionRequest{
Model: api.Config.Model,
MaxTokens: api.Config.MaxLength,
Temperature: 1.0,
Messages: []openai.ChatCompletionMessage{
{
Role: openai.ChatMessageRoleUser,
Content: requestMsg,
},
},
Model: api.Config.Model,
MaxTokens: api.Config.MaxLength,
Temperature: 1.0,
Messages: reqMsgs,
Stream: true,
TopP: 1,
FrequencyPenalty: 0.1,
@@ -151,7 +147,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex *
if len(response.Choices) > 0 {
var s string
if i == 0 {
s = fmt.Sprintf(`%s# %s`, s, requestMsg)
s = fmt.Sprintf("%s# %s\n\n", s, prompt)
}
for _, choice := range response.Choices {
s = s + choice.Delta.Content
@@ -173,11 +169,12 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex *
api.Logger.LogInfo(fmt.Sprintf("[RESPONSE] %s\n", strResp))
}
case openai.GPT3TextDavinci003, openai.GPT3TextDavinci002, openai.GPT3TextCurie001, openai.GPT3TextBabbage001, openai.GPT3TextAda001, openai.GPT3TextDavinci001, openai.GPT3DavinciInstructBeta, openai.GPT3Davinci, openai.GPT3CurieInstructBeta, openai.GPT3Curie, openai.GPT3Ada, openai.GPT3Babbage:
prompt := reqMsgs[len(reqMsgs)-1].Content
req := openai.CompletionRequest{
Model: api.Config.Model,
MaxTokens: api.Config.MaxLength,
Temperature: 0.6,
Prompt: requestMsg,
Prompt: prompt,
Stream: true,
//Stop: []string{"\n\n\n"},
TopP: 1,
@@ -236,7 +233,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex *
if len(response.Choices) > 0 {
var s string
if i == 0 {
s = fmt.Sprintf(`%s# %s`, s, requestMsg)
s = fmt.Sprintf("%s# %s\n\n", s, prompt)
}
for _, choice := range response.Choices {
s = s + choice.Text
@@ -393,6 +390,8 @@ func (api *Api) WsChat(c *gin.Context) {
api.Logger.LogInfo(fmt.Sprintf("websocket connection open"))
cli := openai.NewClient(api.Config.ApiKey)
reqMsgs := make([]openai.ChatCompletionMessage, 0)
var latestRequestTime time.Time
for {
if isClosed {
@@ -466,7 +465,11 @@ func (api *Api) WsChat(c *gin.Context) {
mutex.Lock()
_ = conn.WriteJSON(chatMsg)
mutex.Unlock()
go api.GetChatMessage(conn, cli, mutex, requestMsg)
reqMsgs = append(reqMsgs, openai.ChatCompletionMessage{
Role: openai.ChatMessageRoleUser,
Content: requestMsg,
})
go api.GetChatMessage(conn, cli, mutex, reqMsgs)
}
}
}

2
go.mod
View File

@@ -7,7 +7,7 @@ require (
github.com/gin-gonic/gin v1.8.2
github.com/google/uuid v1.3.0
github.com/gorilla/websocket v1.5.0
github.com/sashabaranov/go-openai v1.5.4
github.com/sashabaranov/go-openai v1.5.7
github.com/sirupsen/logrus v1.9.0
gopkg.in/yaml.v3 v3.0.1
)