diff --git a/chat/service.go b/chat/service.go index 435f677..2bec4c7 100644 --- a/chat/service.go +++ b/chat/service.go @@ -2,7 +2,9 @@ package chat import ( "context" + "errors" "fmt" + "io" "net/http" "strings" "sync" @@ -11,7 +13,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/gorilla/websocket" - gogpt "github.com/sashabaranov/go-gpt3" + openai "github.com/sashabaranov/go-openai" ) type Api struct { @@ -72,16 +74,17 @@ func (api *Api) wsPingMsg(conn *websocket.Conn, chClose, chIsCloseSet chan int) } } -func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *sync.Mutex, requestMsg string) { +func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex *sync.Mutex, requestMsg string) { var err error var strResp string - req := gogpt.ChatCompletionRequest{ - Model: gogpt.GPT3Dot5Turbo0301, + model := openai.GPT3Dot5Turbo0301 + req := openai.ChatCompletionRequest{ + Model: model, MaxTokens: api.Config.MaxLength, Temperature: 1.0, - Messages: []gogpt.ChatCompletionMessage{ + Messages: []openai.ChatCompletionMessage{ { - Role: "user", + Role: openai.ChatMessageRoleUser, Content: requestMsg, }, }, @@ -95,7 +98,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s stream, err := cli.CreateChatCompletionStream(ctx, req) if err != nil { - err = fmt.Errorf("[ERROR] create chatGPT stream error: %s", err.Error()) + err = fmt.Errorf("[ERROR] create chatGPT stream model=%s error: %s", model, err.Error()) chatMsg := Message{ Kind: "error", Msg: err.Error(), @@ -117,12 +120,17 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *gogpt.Client, mutex *s if err != nil { var s string var kind string - if i == 0 { - s = "[ERROR] NO RESPONSE, PLEASE RETRY" - kind = "retry" + if errors.Is(err, io.EOF) { + if i == 0 { + s = "[ERROR] NO RESPONSE, PLEASE RETRY" + kind = "retry" + } else { + s = "\n\n###### [END] ######" + kind = "chat" + } } else { - s = "\n\n###### [END] ######" - kind = "chat" + s = fmt.Sprintf("[ERROR] %s", err.Error()) + kind = "error" } chatMsg := Message{ Kind: kind, @@ -212,7 +220,7 @@ func (api *Api) WsChat(c *gin.Context) { }() api.Logger.LogInfo(fmt.Sprintf("websocket connection open")) - cli := gogpt.NewClient(api.Config.AppKey) + cli := openai.NewClient(api.Config.AppKey) var latestRequestTime time.Time for { diff --git a/go.mod b/go.mod index 53975c8..bacbcbd 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/gin-gonic/gin v1.8.2 github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.5.0 - github.com/sashabaranov/go-gpt3 v1.3.3 + github.com/sashabaranov/go-openai v1.4.1 github.com/sirupsen/logrus v1.9.0 gopkg.in/yaml.v3 v3.0.1 )