diff --git a/chat/service.go b/chat/service.go index 1bc576c..efba524 100644 --- a/chat/service.go +++ b/chat/service.go @@ -46,25 +46,6 @@ func (api *Api) responseFunc(c *gin.Context, startTime time.Time, status, msg st c.JSON(httpStatus, ar) } -func (api *Api) wsCheckConnectStatus(conn *websocket.Conn, chClose chan int) { - var err error - defer func() { - conn.Close() - }() - conn.SetReadDeadline(time.Now().Add(pingWait)) - conn.SetPongHandler(func(s string) error { - conn.SetReadDeadline(time.Now().Add(pingWait)) - return nil - }) - for { - _, _, err = conn.ReadMessage() - if err != nil { - chClose <- 0 - return - } - } -} - func (api *Api) wsPingMsg(conn *websocket.Conn, chClose, chIsCloseSet chan int) { var err error ticker := time.NewTicker(pingPeriod) @@ -96,7 +77,6 @@ func (api *Api) wsPingMsg(conn *websocket.Conn, chClose, chIsCloseSet chan int) func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *sync.Mutex, requestMsg string) { var err error var strResp string - var end bool req := gogpt.CompletionRequest{ Model: gogpt.GPT3TextDavinci003, MaxTokens: api.Config.MaxLength, @@ -109,23 +89,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s PresencePenalty: 0.1, } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(api.Config.TimeoutSeconds)) - defer func() { - if !end { - err = fmt.Errorf("[ERROR] context timeout") - chatMsg := Message{ - Kind: "error", - Msg: err.Error(), - MsgId: uuid.New().String(), - CreateTime: time.Now().Format("2006-01-02 15:04:05"), - } - mutex.Lock() - _ = conn.WriteJSON(chatMsg) - mutex.Unlock() - api.Logger.LogError(err.Error()) - } - cancel() - }() + ctx := context.Background() stream, err := cli.CreateCompletionStream(ctx, req) if err != nil { @@ -149,7 +113,6 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s for { response, err := stream.Recv() if errors.Is(err, io.EOF) { - end = true var s string var kind string if i == 0 { @@ -173,7 +136,6 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s } break } else if err != nil { - end = true err = fmt.Errorf("[ERROR] receive chatGPT stream error: %s", err.Error()) chatMsg := Message{ Kind: "error", @@ -231,7 +193,7 @@ func (api *Api) WsChat(c *gin.Context) { mutex := &sync.Mutex{} conn, err := wsupgrader.Upgrade(c.Writer, c.Request, nil) if err != nil { - err = fmt.Errorf("failed to upgrade websocket %s", err.Error()) + err = fmt.Errorf("[ERROR] failed to upgrade websocket %s", err.Error()) msg = err.Error() api.responseFunc(c, startTime, status, msg, httpStatus, data) return @@ -240,13 +202,18 @@ func (api *Api) WsChat(c *gin.Context) { _ = conn.Close() }() + _ = conn.SetReadDeadline(time.Now().Add(pingWait)) + conn.SetPongHandler(func(s string) error { + _ = conn.SetReadDeadline(time.Now().Add(pingWait)) + return nil + }) + var isClosed bool chClose := make(chan int) chIsCloseSet := make(chan int) defer func() { conn.Close() }() - go api.wsCheckConnectStatus(conn, chClose) go api.wsPingMsg(conn, chClose, chIsCloseSet) go func() { for { @@ -269,7 +236,7 @@ func (api *Api) WsChat(c *gin.Context) { // read in a message messageType, bs, err := conn.ReadMessage() if err != nil { - err = fmt.Errorf("read message error: %s", err.Error()) + err = fmt.Errorf("[ERROR] read message error: %s", err.Error()) api.Logger.LogError(err.Error()) return } @@ -283,7 +250,7 @@ func (api *Api) WsChat(c *gin.Context) { ok = true } else { if time.Since(latestRequestTime) < time.Second*time.Duration(api.Config.IntervalSeconds) { - err = fmt.Errorf("please wait %d seconds for next query", api.Config.IntervalSeconds) + err = fmt.Errorf("[ERROR] please wait %d seconds for next query", api.Config.IntervalSeconds) chatMsg := Message{ Kind: "error", Msg: err.Error(), @@ -301,7 +268,7 @@ func (api *Api) WsChat(c *gin.Context) { } if ok { if len(strings.Trim(requestMsg, " ")) < 2 { - err = fmt.Errorf("message too short") + err = fmt.Errorf("[ERROR] message too short") chatMsg := Message{ Kind: "error", Msg: err.Error(), @@ -329,9 +296,13 @@ func (api *Api) WsChat(c *gin.Context) { isClosed = true api.Logger.LogInfo("[CLOSED] websocket receive closed message") case websocket.PingMessage: + _ = conn.SetReadDeadline(time.Now().Add(pingWait)) api.Logger.LogInfo("[PING] websocket receive ping message") + case websocket.PongMessage: + _ = conn.SetReadDeadline(time.Now().Add(pingWait)) + api.Logger.LogInfo("[PONG] websocket receive pong message") default: - err = fmt.Errorf("websocket receive message type not text") + err = fmt.Errorf("[ERROR] websocket receive message type not text") chatMsg := Message{ Kind: "error", Msg: err.Error(), @@ -341,7 +312,7 @@ func (api *Api) WsChat(c *gin.Context) { mutex.Lock() _ = conn.WriteJSON(chatMsg) mutex.Unlock() - api.Logger.LogError("websocket receive message type not text") + api.Logger.LogError(err.Error()) return } } diff --git a/chat/types.go b/chat/types.go index d83a619..b8e0484 100644 --- a/chat/types.go +++ b/chat/types.go @@ -6,5 +6,4 @@ type Config struct { IntervalSeconds int `yaml:"intervalSeconds" json:"intervalSeconds" bson:"intervalSeconds" validate:"required"` MaxLength int `yaml:"maxLength" json:"maxLength" bson:"maxLength" validate:"required"` Cors bool `yaml:"cors" json:"cors" bson:"cors" validate:""` - TimeoutSeconds int `yaml:"timeoutSeconds" json:"timeoutSeconds" bson:"timeoutSeconds" validate:"required"` } diff --git a/config.yaml b/config.yaml index baea4d8..9883b50 100644 --- a/config.yaml +++ b/config.yaml @@ -8,6 +8,4 @@ intervalSeconds: 5 maxLength: 2000 # 是否允许cors跨域 cors: true -# 问题反馈的超时时间,单位:秒 -timeoutSeconds: 300