406 lines
10 KiB
Go
406 lines
10 KiB
Go
package chat
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/uuid"
|
|
"github.com/gorilla/websocket"
|
|
openai "github.com/sashabaranov/go-openai"
|
|
)
|
|
|
|
type Api struct {
|
|
Config Config
|
|
Logger
|
|
}
|
|
|
|
type (
	// ApiResponse is the JSON envelope returned by the plain HTTP responses:
	// an outcome status, a human-readable message, how long the request took
	// (a time.Duration rendered with its String method) and an optional
	// payload.
	ApiResponse struct {
		Status   string      `yaml:"status" json:"status" bson:"status" validate:""`
		Msg      string      `yaml:"msg" json:"msg" bson:"msg" validate:""`
		Duration string      `yaml:"duration" json:"duration" bson:"duration" validate:""`
		Data     interface{} `yaml:"data" json:"data" bson:"data" validate:""`
	}

	// Message is one JSON frame pushed to the WebSocket client.
	//
	// Kind tells the client how to treat Msg; the values written in this file
	// are "receive" (echo of the user's prompt), "chat" (a piece of the
	// model's answer), "retry" (the model returned nothing) and "error".
	// All "chat" pieces of one streamed answer carry the same MsgId.
	// CreateTime is formatted with the layout "2006-01-02 15:04:05".
	Message struct {
		Msg        string `yaml:"msg" json:"msg" bson:"msg" validate:""`
		MsgId      string `yaml:"msgId" json:"msgId" bson:"msgId" validate:""`
		Kind       string `yaml:"kind" json:"kind" bson:"kind" validate:""`
		CreateTime string `yaml:"createTime" json:"createTime" bson:"createTime" validate:""`
	}
)
|
|
|
|
func (api *Api) responseFunc(c *gin.Context, startTime time.Time, status, msg string, httpStatus int, data map[string]interface{}) {
|
|
duration := time.Since(startTime)
|
|
ar := ApiResponse{
|
|
Status: status,
|
|
Msg: msg,
|
|
Duration: duration.String(),
|
|
Data: data,
|
|
}
|
|
c.JSON(httpStatus, ar)
|
|
}
|
|
|
|
func (api *Api) wsPingMsg(conn *websocket.Conn, chClose, chIsCloseSet chan int) {
|
|
var err error
|
|
ticker := time.NewTicker(PingPeriod)
|
|
|
|
var mutex = &sync.Mutex{}
|
|
|
|
defer func() {
|
|
ticker.Stop()
|
|
conn.Close()
|
|
}()
|
|
for {
|
|
select {
|
|
case <-ticker.C:
|
|
conn.SetWriteDeadline(time.Now().Add(PingWait))
|
|
mutex.Lock()
|
|
err = conn.WriteMessage(websocket.PingMessage, nil)
|
|
if err != nil {
|
|
return
|
|
}
|
|
mutex.Unlock()
|
|
case <-chClose:
|
|
api.LogInfo(fmt.Sprintf("# websocket connection closed"))
|
|
chIsCloseSet <- 0
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex *sync.Mutex, requestMsg string) {
|
|
var err error
|
|
var strResp string
|
|
|
|
ctx := context.Background()
|
|
|
|
switch api.Config.Model {
|
|
case openai.GPT3Dot5Turbo0301, openai.GPT3Dot5Turbo, openai.GPT4, openai.GPT40314, openai.GPT432K0314, openai.GPT432K:
|
|
req := openai.ChatCompletionRequest{
|
|
Model: api.Config.Model,
|
|
MaxTokens: api.Config.MaxLength,
|
|
Temperature: 1.0,
|
|
Messages: []openai.ChatCompletionMessage{
|
|
{
|
|
Role: openai.ChatMessageRoleUser,
|
|
Content: requestMsg,
|
|
},
|
|
},
|
|
Stream: true,
|
|
TopP: 1,
|
|
FrequencyPenalty: 0.1,
|
|
PresencePenalty: 0.1,
|
|
}
|
|
|
|
stream, err := cli.CreateChatCompletionStream(ctx, req)
|
|
if err != nil {
|
|
err = fmt.Errorf("[ERROR] create chatGPT stream model=%s error: %s", api.Config.Model, err.Error())
|
|
chatMsg := Message{
|
|
Kind: "error",
|
|
Msg: err.Error(),
|
|
MsgId: uuid.New().String(),
|
|
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
|
|
}
|
|
mutex.Lock()
|
|
_ = conn.WriteJSON(chatMsg)
|
|
mutex.Unlock()
|
|
api.Logger.LogError(err.Error())
|
|
return
|
|
}
|
|
defer stream.Close()
|
|
|
|
id := uuid.New().String()
|
|
var i int
|
|
for {
|
|
response, err := stream.Recv()
|
|
if err != nil {
|
|
var s string
|
|
var kind string
|
|
if errors.Is(err, io.EOF) {
|
|
if i == 0 {
|
|
s = "[ERROR] NO RESPONSE, PLEASE RETRY"
|
|
kind = "retry"
|
|
} else {
|
|
s = "\n\n###### [END] ######"
|
|
kind = "chat"
|
|
}
|
|
} else {
|
|
s = fmt.Sprintf("[ERROR] %s", err.Error())
|
|
kind = "error"
|
|
}
|
|
chatMsg := Message{
|
|
Kind: kind,
|
|
Msg: s,
|
|
MsgId: id,
|
|
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
|
|
}
|
|
mutex.Lock()
|
|
_ = conn.WriteJSON(chatMsg)
|
|
mutex.Unlock()
|
|
break
|
|
}
|
|
|
|
if len(response.Choices) > 0 {
|
|
var s string
|
|
if i == 0 {
|
|
s = fmt.Sprintf(`%s# %s`, s, requestMsg)
|
|
}
|
|
for _, choice := range response.Choices {
|
|
s = s + choice.Delta.Content
|
|
}
|
|
strResp = strResp + s
|
|
chatMsg := Message{
|
|
Kind: "chat",
|
|
Msg: s,
|
|
MsgId: id,
|
|
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
|
|
}
|
|
mutex.Lock()
|
|
_ = conn.WriteJSON(chatMsg)
|
|
mutex.Unlock()
|
|
}
|
|
i = i + 1
|
|
}
|
|
if strResp != "" {
|
|
api.Logger.LogInfo(fmt.Sprintf("[RESPONSE] %s\n", strResp))
|
|
}
|
|
case openai.GPT3TextDavinci003, openai.GPT3TextDavinci002, openai.GPT3TextCurie001, openai.GPT3TextBabbage001, openai.GPT3TextAda001, openai.GPT3TextDavinci001, openai.GPT3DavinciInstructBeta, openai.GPT3Davinci, openai.GPT3CurieInstructBeta, openai.GPT3Curie, openai.GPT3Ada, openai.GPT3Babbage:
|
|
req := openai.CompletionRequest{
|
|
Model: api.Config.Model,
|
|
MaxTokens: api.Config.MaxLength,
|
|
Temperature: 0.6,
|
|
Prompt: requestMsg,
|
|
Stream: true,
|
|
//Stop: []string{"\n\n\n"},
|
|
TopP: 1,
|
|
FrequencyPenalty: 0.1,
|
|
PresencePenalty: 0.1,
|
|
}
|
|
|
|
stream, err := cli.CreateCompletionStream(ctx, req)
|
|
if err != nil {
|
|
err = fmt.Errorf("[ERROR] create chatGPT stream model=%s error: %s", api.Config.Model, err.Error())
|
|
chatMsg := Message{
|
|
Kind: "error",
|
|
Msg: err.Error(),
|
|
MsgId: uuid.New().String(),
|
|
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
|
|
}
|
|
mutex.Lock()
|
|
_ = conn.WriteJSON(chatMsg)
|
|
mutex.Unlock()
|
|
api.Logger.LogError(err.Error())
|
|
return
|
|
}
|
|
defer stream.Close()
|
|
|
|
id := uuid.New().String()
|
|
var i int
|
|
for {
|
|
response, err := stream.Recv()
|
|
if err != nil {
|
|
var s string
|
|
var kind string
|
|
if errors.Is(err, io.EOF) {
|
|
if i == 0 {
|
|
s = "[ERROR] NO RESPONSE, PLEASE RETRY"
|
|
kind = "retry"
|
|
} else {
|
|
s = "\n\n###### [END] ######"
|
|
kind = "chat"
|
|
}
|
|
} else {
|
|
s = fmt.Sprintf("[ERROR] %s", err.Error())
|
|
kind = "error"
|
|
}
|
|
chatMsg := Message{
|
|
Kind: kind,
|
|
Msg: s,
|
|
MsgId: id,
|
|
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
|
|
}
|
|
mutex.Lock()
|
|
_ = conn.WriteJSON(chatMsg)
|
|
mutex.Unlock()
|
|
break
|
|
}
|
|
|
|
if len(response.Choices) > 0 {
|
|
var s string
|
|
if i == 0 {
|
|
s = fmt.Sprintf(`%s# %s`, s, requestMsg)
|
|
}
|
|
for _, choice := range response.Choices {
|
|
s = s + choice.Text
|
|
}
|
|
strResp = strResp + s
|
|
chatMsg := Message{
|
|
Kind: "chat",
|
|
Msg: s,
|
|
MsgId: id,
|
|
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
|
|
}
|
|
mutex.Lock()
|
|
_ = conn.WriteJSON(chatMsg)
|
|
mutex.Unlock()
|
|
}
|
|
i = i + 1
|
|
}
|
|
if strResp != "" {
|
|
api.Logger.LogInfo(fmt.Sprintf("[RESPONSE] %s\n", strResp))
|
|
}
|
|
default:
|
|
err = fmt.Errorf("model not exists")
|
|
api.Logger.LogError(err.Error())
|
|
return
|
|
}
|
|
}
|
|
|
|
func (api *Api) WsChat(c *gin.Context) {
|
|
startTime := time.Now()
|
|
status := StatusFail
|
|
msg := ""
|
|
httpStatus := http.StatusForbidden
|
|
data := map[string]interface{}{}
|
|
|
|
wsupgrader := websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
}
|
|
wsupgrader.CheckOrigin = func(r *http.Request) bool {
|
|
return true
|
|
}
|
|
mutex := &sync.Mutex{}
|
|
conn, err := wsupgrader.Upgrade(c.Writer, c.Request, nil)
|
|
if err != nil {
|
|
err = fmt.Errorf("[ERROR] failed to upgrade websocket %s", err.Error())
|
|
msg = err.Error()
|
|
api.responseFunc(c, startTime, status, msg, httpStatus, data)
|
|
return
|
|
}
|
|
defer func() {
|
|
_ = conn.Close()
|
|
}()
|
|
|
|
_ = conn.SetReadDeadline(time.Now().Add(PingWait))
|
|
conn.SetPongHandler(func(s string) error {
|
|
_ = conn.SetReadDeadline(time.Now().Add(PingWait))
|
|
return nil
|
|
})
|
|
|
|
var isClosed bool
|
|
chClose := make(chan int)
|
|
chIsCloseSet := make(chan int)
|
|
defer func() {
|
|
conn.Close()
|
|
}()
|
|
go api.wsPingMsg(conn, chClose, chIsCloseSet)
|
|
go func() {
|
|
for {
|
|
select {
|
|
case <-chIsCloseSet:
|
|
isClosed = true
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
api.Logger.LogInfo(fmt.Sprintf("websocket connection open"))
|
|
cli := openai.NewClient(api.Config.ApiKey)
|
|
|
|
var latestRequestTime time.Time
|
|
for {
|
|
if isClosed {
|
|
return
|
|
}
|
|
// read in a message
|
|
messageType, bs, err := conn.ReadMessage()
|
|
if err != nil {
|
|
err = fmt.Errorf("[ERROR] read message error: %s", err.Error())
|
|
api.Logger.LogError(err.Error())
|
|
return
|
|
}
|
|
switch messageType {
|
|
case websocket.TextMessage:
|
|
requestMsg := string(bs)
|
|
api.Logger.LogInfo(fmt.Sprintf("[REQUEST] %s", requestMsg))
|
|
var ok bool
|
|
if latestRequestTime.IsZero() {
|
|
latestRequestTime = time.Now()
|
|
ok = true
|
|
} else {
|
|
if time.Since(latestRequestTime) < time.Second*time.Duration(api.Config.IntervalSeconds) {
|
|
err = fmt.Errorf("[ERROR] please wait %d seconds for next query", api.Config.IntervalSeconds)
|
|
chatMsg := Message{
|
|
Kind: "error",
|
|
Msg: err.Error(),
|
|
MsgId: uuid.New().String(),
|
|
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
|
|
}
|
|
mutex.Lock()
|
|
_ = conn.WriteJSON(chatMsg)
|
|
mutex.Unlock()
|
|
api.Logger.LogError(err.Error())
|
|
} else {
|
|
ok = true
|
|
latestRequestTime = time.Now()
|
|
}
|
|
}
|
|
if ok {
|
|
if len(strings.Trim(requestMsg, " ")) < 2 {
|
|
err = fmt.Errorf("[ERROR] message too short")
|
|
chatMsg := Message{
|
|
Kind: "error",
|
|
Msg: err.Error(),
|
|
MsgId: uuid.New().String(),
|
|
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
|
|
}
|
|
mutex.Lock()
|
|
_ = conn.WriteJSON(chatMsg)
|
|
mutex.Unlock()
|
|
api.Logger.LogError(err.Error())
|
|
} else {
|
|
chatMsg := Message{
|
|
Kind: "receive",
|
|
Msg: requestMsg,
|
|
MsgId: uuid.New().String(),
|
|
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
|
|
}
|
|
mutex.Lock()
|
|
_ = conn.WriteJSON(chatMsg)
|
|
mutex.Unlock()
|
|
go api.GetChatMessage(conn, cli, mutex, requestMsg)
|
|
}
|
|
}
|
|
case websocket.CloseMessage:
|
|
isClosed = true
|
|
api.Logger.LogInfo("[CLOSED] websocket receive closed message")
|
|
case websocket.PingMessage:
|
|
_ = conn.SetReadDeadline(time.Now().Add(PingWait))
|
|
api.Logger.LogInfo("[PING] websocket receive ping message")
|
|
case websocket.PongMessage:
|
|
_ = conn.SetReadDeadline(time.Now().Add(PingWait))
|
|
api.Logger.LogInfo("[PONG] websocket receive pong message")
|
|
default:
|
|
err = fmt.Errorf("[ERROR] websocket receive message type not text")
|
|
chatMsg := Message{
|
|
Kind: "error",
|
|
Msg: err.Error(),
|
|
MsgId: uuid.New().String(),
|
|
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
|
|
}
|
|
mutex.Lock()
|
|
_ = conn.WriteJSON(chatMsg)
|
|
mutex.Unlock()
|
|
api.Logger.LogError(err.Error())
|
|
return
|
|
}
|
|
}
|
|
}
|