diff --git a/chat/service.go b/chat/service.go index 0fea5e1..e34aeba 100644 --- a/chat/service.go +++ b/chat/service.go @@ -6,7 +6,6 @@ import ( "fmt" "io" "net/http" - "runtime" "strings" "sync" "time" @@ -108,7 +107,9 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s PresencePenalty: 0.1, } - ctx := context.Background() + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(time.Second*time.Duration(api.Config.TimeoutSeconds))) + defer cancel() + stream, err := cli.CreateCompletionStream(ctx, req) if err != nil { err = fmt.Errorf("[ERROR] create chatGPT stream error: %s", err.Error()) @@ -153,8 +154,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s api.Logger.LogError(s) } break - } - if err != nil { + } else if err != nil { err = fmt.Errorf("[ERROR] receive chatGPT stream error: %s", err.Error()) chatMsg := Message{ Kind: "error", @@ -193,7 +193,6 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s if strResp != "" { api.Logger.LogInfo(fmt.Sprintf("[RESPONSE] %s", strResp)) } - runtime.GC() } func (api *Api) WsChat(c *gin.Context) { diff --git a/chat/types.go b/chat/types.go index b8e0484..d83a619 100644 --- a/chat/types.go +++ b/chat/types.go @@ -6,4 +6,5 @@ type Config struct { IntervalSeconds int `yaml:"intervalSeconds" json:"intervalSeconds" bson:"intervalSeconds" validate:"required"` MaxLength int `yaml:"maxLength" json:"maxLength" bson:"maxLength" validate:"required"` Cors bool `yaml:"cors" json:"cors" bson:"cors" validate:""` + TimeoutSeconds int `yaml:"timeoutSeconds" json:"timeoutSeconds" bson:"timeoutSeconds" validate:"required"` } diff --git a/config.yaml b/config.yaml index de0398c..2e0f83d 100644 --- a/config.yaml +++ b/config.yaml @@ -5,6 +5,9 @@ port: 9000 # 问题发送的时间间隔不能小于多长时间,单位:秒 intervalSeconds: 5 # 返回答案的最大长度 -maxLength: 1500 +maxLength: 2000 # 是否允许cors跨域 cors: true +# 问题反馈的超时时间,单位:秒 +timeoutSeconds: 5 +