diff --git a/chat/service.go b/chat/service.go index e34aeba..42d3679 100644 --- a/chat/service.go +++ b/chat/service.go @@ -94,7 +94,9 @@ func (api *Api) wsPingMsg(conn *websocket.Conn, chClose, chIsCloseSet chan int) } func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *sync.Mutex, requestMsg string) { + var err error var strResp string + var end bool req := gogpt.CompletionRequest{ Model: gogpt.GPT3TextDavinci003, MaxTokens: api.Config.MaxLength, @@ -107,8 +109,23 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s PresencePenalty: 0.1, } - ctx, cancel := context.WithTimeout(context.Background(), time.Duration(time.Second*time.Duration(api.Config.TimeoutSeconds))) - defer cancel() + ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(api.Config.TimeoutSeconds)) + defer func() { + if !end { + err = fmt.Errorf("[ERROR] context timeout") + chatMsg := Message{ + Kind: "error", + Msg: err.Error(), + MsgId: uuid.New().String(), + CreateTime: time.Now().Format("2006-01-02 15:04:05"), + } + mutex.Lock() + _ = conn.WriteJSON(chatMsg) + mutex.Unlock() + api.Logger.LogError(err.Error()) + } + cancel() + }() stream, err := cli.CreateCompletionStream(ctx, req) if err != nil { @@ -132,6 +149,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s for { response, err := stream.Recv() if errors.Is(err, io.EOF) { + end = true var s string var kind string if i == 0 { @@ -155,6 +173,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s } break } else if err != nil { + end = true err = fmt.Errorf("[ERROR] receive chatGPT stream error: %s", err.Error()) chatMsg := Message{ Kind: "error", @@ -294,7 +313,6 @@ func (api *Api) WsChat(c *gin.Context) { mutex.Unlock() api.Logger.LogError(err.Error()) } else { - chatMsg := Message{ Kind: "receive", Msg: requestMsg,