From d7dfa7b216dac2a60d7996ba89858a4e05ccbdb9 Mon Sep 17 00:00:00 2001 From: cookeem Date: Mon, 13 Feb 2023 11:05:52 +0800 Subject: [PATCH] =?UTF-8?q?stream.Recv()=E4=BC=9A=E5=AF=BC=E8=87=B4cpu=201?= =?UTF-8?q?00=E7=9A=84=E9=97=AE=E9=A2=98=EF=BC=8C=E5=8D=87=E7=BA=A7gpt-3?= =?UTF-8?q?=E4=BE=9D=E8=B5=96=E5=BA=93?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- chat/service.go | 99 ++++++++++++++++++++----------------------------- chat/types.go | 1 - config.yaml | 3 -- go.mod | 2 +- 4 files changed, 42 insertions(+), 63 deletions(-) diff --git a/chat/service.go b/chat/service.go index 42d3679..7e2e982 100644 --- a/chat/service.go +++ b/chat/service.go @@ -96,7 +96,6 @@ func (api *Api) wsPingMsg(conn *websocket.Conn, chClose, chIsCloseSet chan int) func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *sync.Mutex, requestMsg string) { var err error var strResp string - var end bool req := gogpt.CompletionRequest{ Model: gogpt.GPT3TextDavinci003, MaxTokens: api.Config.MaxLength, @@ -109,23 +108,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s PresencePenalty: 0.1, } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(api.Config.TimeoutSeconds)) - defer func() { - if !end { - err = fmt.Errorf("[ERROR] context timeout") - chatMsg := Message{ - Kind: "error", - Msg: err.Error(), - MsgId: uuid.New().String(), - CreateTime: time.Now().Format("2006-01-02 15:04:05"), - } - mutex.Lock() - _ = conn.WriteJSON(chatMsg) - mutex.Unlock() - api.Logger.LogError(err.Error()) - } - cancel() - }() + ctx := context.Background() stream, err := cli.CreateCompletionStream(ctx, req) if err != nil { @@ -142,53 +125,53 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s api.Logger.LogError(err.Error()) return } - defer stream.Close() + defer func() { + stream.Close() + }() id := uuid.New().String() var i int for { response, err := stream.Recv() - if errors.Is(err, io.EOF) { - end = true - var s string - var kind string - if i == 0 { - s = "[ERROR] NO RESPONSE, PLEASE RETRY" - kind = "retry" + if err != nil { + if errors.Is(err, io.EOF) { + var s string + var kind string + if i == 0 { + s = "[ERROR] NO RESPONSE, PLEASE RETRY" + kind = "retry" + } else { + s = "\n\n###### [END] ######" + kind = "chat" + } + chatMsg := Message{ + Kind: kind, + Msg: s, + MsgId: id, + CreateTime: time.Now().Format("2006-01-02 15:04:05"), + } + mutex.Lock() + _ = conn.WriteJSON(chatMsg) + mutex.Unlock() + if kind == "retry" { + api.Logger.LogError(s) + } + break } else { - s = "\n\n###### [END] ######" - kind = "chat" + err = fmt.Errorf("[ERROR] receive chatGPT stream error: %s", err.Error()) + chatMsg := Message{ + Kind: "error", + Msg: err.Error(), + MsgId: uuid.New().String(), + CreateTime: time.Now().Format("2006-01-02 15:04:05"), + } + mutex.Lock() + _ = conn.WriteJSON(chatMsg) + mutex.Unlock() + api.Logger.LogError(err.Error()) + break } - chatMsg := Message{ - Kind: kind, - Msg: s, - MsgId: id, - CreateTime: time.Now().Format("2006-01-02 15:04:05"), - } - mutex.Lock() - _ = conn.WriteJSON(chatMsg) - mutex.Unlock() - if kind == "retry" { - api.Logger.LogError(s) - } - break - } else if err != nil { - end = true - err = fmt.Errorf("[ERROR] receive chatGPT stream error: %s", err.Error()) - chatMsg := Message{ - Kind: "error", - Msg: err.Error(), - MsgId: id, - CreateTime: time.Now().Format("2006-01-02 15:04:05"), - } - mutex.Lock() - _ = conn.WriteJSON(chatMsg) - mutex.Unlock() - api.Logger.LogError(err.Error()) - break - } - - if len(response.Choices) > 0 { + } else if len(response.Choices) > 0 { var s string if i == 0 { s = fmt.Sprintf(`%s# %s`, s, requestMsg) diff --git a/chat/types.go b/chat/types.go index d83a619..b8e0484 100644 --- a/chat/types.go +++ b/chat/types.go @@ -6,5 +6,4 @@ type Config struct { IntervalSeconds int `yaml:"intervalSeconds" json:"intervalSeconds" bson:"intervalSeconds" validate:"required"` MaxLength int `yaml:"maxLength" json:"maxLength" bson:"maxLength" validate:"required"` Cors bool `yaml:"cors" json:"cors" bson:"cors" validate:""` - TimeoutSeconds int `yaml:"timeoutSeconds" json:"timeoutSeconds" bson:"timeoutSeconds" validate:"required"` } diff --git a/config.yaml b/config.yaml index ff62588..5d1c5c0 100644 --- a/config.yaml +++ b/config.yaml @@ -8,6 +8,3 @@ intervalSeconds: 5 maxLength: 2000 # 是否允许cors跨域 cors: true -# 问题反馈的超时时间,单位:秒 -timeoutSeconds: 180 - diff --git a/go.mod b/go.mod index 6bb71b7..b5fdd77 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/gin-gonic/gin v1.8.2 github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 - github.com/sashabaranov/go-gpt3 v1.0.0 + github.com/sashabaranov/go-gpt3 v1.0.1 github.com/sirupsen/logrus v1.9.0 gopkg.in/yaml.v3 v3.0.1 )