From 5df6c9205024a5c6010f0f4ca807ed34029be409 Mon Sep 17 00:00:00 2001 From: cookeem Date: Mon, 13 Feb 2023 11:10:09 +0800 Subject: [PATCH] =?UTF-8?q?Revert=20"stream.Recv()=E4=BC=9A=E5=AF=BC?= =?UTF-8?q?=E8=87=B4cpu=20100=E7=9A=84=E9=97=AE=E9=A2=98=EF=BC=8C=E5=8D=87?= =?UTF-8?q?=E7=BA=A7gpt-3=E4=BE=9D=E8=B5=96=E5=BA=93"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit d7dfa7b216dac2a60d7996ba89858a4e05ccbdb9. --- chat/service.go | 99 +++++++++++++++++++++++++++++-------------------- chat/types.go | 1 + config.yaml | 3 ++ go.mod | 2 +- 4 files changed, 63 insertions(+), 42 deletions(-) diff --git a/chat/service.go b/chat/service.go index 7e2e982..42d3679 100644 --- a/chat/service.go +++ b/chat/service.go @@ -96,6 +96,7 @@ func (api *Api) wsPingMsg(conn *websocket.Conn, chClose, chIsCloseSet chan int) func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *sync.Mutex, requestMsg string) { var err error var strResp string + var end bool req := gogpt.CompletionRequest{ Model: gogpt.GPT3TextDavinci003, MaxTokens: api.Config.MaxLength, @@ -108,7 +109,23 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s PresencePenalty: 0.1, } - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(api.Config.TimeoutSeconds)) + defer func() { + if !end { + err = fmt.Errorf("[ERROR] context timeout") + chatMsg := Message{ + Kind: "error", + Msg: err.Error(), + MsgId: uuid.New().String(), + CreateTime: time.Now().Format("2006-01-02 15:04:05"), + } + mutex.Lock() + _ = conn.WriteJSON(chatMsg) + mutex.Unlock() + api.Logger.LogError(err.Error()) + } + cancel() + }() stream, err := cli.CreateCompletionStream(ctx, req) if err != nil { @@ -125,53 +142,53 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s api.Logger.LogError(err.Error()) return } - defer func() { - stream.Close() - }() + defer stream.Close() id := uuid.New().String() var i int for { response, err := stream.Recv() - if err != nil { - if errors.Is(err, io.EOF) { - var s string - var kind string - if i == 0 { - s = "[ERROR] NO RESPONSE, PLEASE RETRY" - kind = "retry" - } else { - s = "\n\n###### [END] ######" - kind = "chat" - } - chatMsg := Message{ - Kind: kind, - Msg: s, - MsgId: id, - CreateTime: time.Now().Format("2006-01-02 15:04:05"), - } - mutex.Lock() - _ = conn.WriteJSON(chatMsg) - mutex.Unlock() - if kind == "retry" { - api.Logger.LogError(s) - } - break + if errors.Is(err, io.EOF) { + end = true + var s string + var kind string + if i == 0 { + s = "[ERROR] NO RESPONSE, PLEASE RETRY" + kind = "retry" } else { - err = fmt.Errorf("[ERROR] receive chatGPT stream error: %s", err.Error()) - chatMsg := Message{ - Kind: "error", - Msg: err.Error(), - MsgId: uuid.New().String(), - CreateTime: time.Now().Format("2006-01-02 15:04:05"), - } - mutex.Lock() - _ = conn.WriteJSON(chatMsg) - mutex.Unlock() - api.Logger.LogError(err.Error()) - break + s = "\n\n###### [END] ######" + kind = "chat" } - } else if len(response.Choices) > 0 { + chatMsg := Message{ + Kind: kind, + Msg: s, + MsgId: id, + CreateTime: time.Now().Format("2006-01-02 15:04:05"), + } + mutex.Lock() + _ = conn.WriteJSON(chatMsg) + mutex.Unlock() + if kind == "retry" { + api.Logger.LogError(s) + } + break + } else if err != nil { + end = true + err = fmt.Errorf("[ERROR] receive chatGPT stream error: %s", err.Error()) + chatMsg := Message{ + Kind: "error", + Msg: err.Error(), + MsgId: id, + CreateTime: time.Now().Format("2006-01-02 15:04:05"), + } + mutex.Lock() + _ = conn.WriteJSON(chatMsg) + mutex.Unlock() + api.Logger.LogError(err.Error()) + break + } + + if len(response.Choices) > 0 { var s string if i == 0 { s = fmt.Sprintf(`%s# %s`, s, requestMsg) diff --git a/chat/types.go b/chat/types.go index b8e0484..d83a619 100644 --- a/chat/types.go +++ b/chat/types.go @@ -6,4 +6,5 @@ type Config struct { IntervalSeconds int `yaml:"intervalSeconds" json:"intervalSeconds" bson:"intervalSeconds" validate:"required"` MaxLength int `yaml:"maxLength" json:"maxLength" bson:"maxLength" validate:"required"` Cors bool `yaml:"cors" json:"cors" bson:"cors" validate:""` + TimeoutSeconds int `yaml:"timeoutSeconds" json:"timeoutSeconds" bson:"timeoutSeconds" validate:"required"` } diff --git a/config.yaml b/config.yaml index 5d1c5c0..ff62588 100644 --- a/config.yaml +++ b/config.yaml @@ -8,3 +8,6 @@ intervalSeconds: 5 maxLength: 2000 # 是否允许cors跨域 cors: true +# 问题反馈的超时时间,单位:秒 +timeoutSeconds: 180 + diff --git a/go.mod b/go.mod index b5fdd77..6bb71b7 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/gin-gonic/gin v1.8.2 github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 - github.com/sashabaranov/go-gpt3 v1.0.1 + github.com/sashabaranov/go-gpt3 v1.0.0 github.com/sirupsen/logrus v1.9.0 gopkg.in/yaml.v3 v3.0.1 )