From 3315e5940f688ece93fb24f3989a0966be2a357b Mon Sep 17 00:00:00 2001 From: cookeem Date: Wed, 22 Mar 2023 23:04:05 +0800 Subject: [PATCH] support generate image by prompt --- chat/common.go | 18 ++++++++ chat/service.go | 114 +++++++++++++++++++++++++++++++++++++++++++----- main.go | 1 + 3 files changed, 122 insertions(+), 11 deletions(-) diff --git a/chat/common.go b/chat/common.go index dca30db..415964c 100644 --- a/chat/common.go +++ b/chat/common.go @@ -1,8 +1,10 @@ package chat import ( + "fmt" "github.com/sashabaranov/go-openai" log "github.com/sirupsen/logrus" + "math/rand" "os" "time" ) @@ -39,6 +41,22 @@ func (logger Logger) LogPanic(args ...interface{}) { log.Panic(args...) } +func RandomString(n int) string { + var letter []rune + lowerChars := "abcdefghijklmnopqrstuvwxyz" + numberChars := "0123456789" + chars := fmt.Sprintf("%s%s", lowerChars, numberChars) + letter = []rune(chars) + var str string + b := make([]rune, n) + seededRand := rand.New(rand.NewSource(time.Now().UnixNano())) + for i := range b { + b[i] = letter[seededRand.Intn(len(letter))] + } + str = string(b) + return str +} + const ( StatusFail string = "FAIL" diff --git a/chat/service.go b/chat/service.go index ae9db64..0b16c55 100644 --- a/chat/service.go +++ b/chat/service.go @@ -2,10 +2,12 @@ package chat import ( "context" + "encoding/base64" "errors" "fmt" "io" "net/http" + "os" "strings" "sync" "time" @@ -100,7 +102,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex * stream, err := cli.CreateChatCompletionStream(ctx, req) if err != nil { - err = fmt.Errorf("[ERROR] create chatGPT stream model=%s error: %s", api.Config.Model, err.Error()) + err = fmt.Errorf("[ERROR] create ChatGPT stream model=%s error: %s", api.Config.Model, err.Error()) chatMsg := Message{ Kind: "error", Msg: err.Error(), @@ -185,7 +187,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex * stream, err := cli.CreateCompletionStream(ctx, req) if err != nil { - err = fmt.Errorf("[ERROR] create chatGPT stream model=%s error: %s", api.Config.Model, err.Error()) + err = fmt.Errorf("[ERROR] create ChatGPT stream model=%s error: %s", api.Config.Model, err.Error()) chatMsg := Message{ Kind: "error", Msg: err.Error(), @@ -262,6 +264,83 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex * } } +func (api *Api) GetImageMessage(conn *websocket.Conn, cli *openai.Client, mutex *sync.Mutex, requestMsg string) { + var err error + + ctx := context.Background() + + prompt := strings.TrimPrefix(requestMsg, "/image ") + req := openai.ImageRequest{ + Prompt: prompt, + Size: openai.CreateImageSize256x256, + ResponseFormat: openai.CreateImageResponseFormatB64JSON, + N: 1, + } + + sendError := func(err error) { + err = fmt.Errorf("[ERROR] generate image error: %s", err.Error()) + chatMsg := Message{ + Kind: "error", + Msg: err.Error(), + MsgId: uuid.New().String(), + CreateTime: time.Now().Format("2006-01-02 15:04:05"), + } + mutex.Lock() + _ = conn.WriteJSON(chatMsg) + mutex.Unlock() + api.Logger.LogError(err.Error()) + } + + resp, err := cli.CreateImage(ctx, req) + if err != nil { + err = fmt.Errorf("[ERROR] generate image error: %s", err.Error()) + sendError(err) + return + } + if len(resp.Data) == 0 { + err = fmt.Errorf("[ERROR] generate image error: result is empty") + sendError(err) + return + } + + imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON) + if err != nil { + err = fmt.Errorf("[ERROR] image base64 decode error: %s", err.Error()) + sendError(err) + return + } + + date := time.Now().Format("2006-01-02") + imageDir := fmt.Sprintf("assets/images/%s", date) + err = os.MkdirAll(imageDir, 0700) + if err != nil { + err = fmt.Errorf("[ERROR] create image directory error: %s", err.Error()) + sendError(err) + return + } + + imageFileName := fmt.Sprintf("%s.png", RandomString(16)) + err = os.WriteFile(fmt.Sprintf("%s/%s", imageDir, imageFileName), imgBytes, 0600) + if err != nil { + err = fmt.Errorf("[ERROR] write png image error: %s", err.Error()) + sendError(err) + return + } + + msg := fmt.Sprintf("api/%s/%s", imageDir, imageFileName) + chatMsg := Message{ + Kind: "image", + Msg: msg, + MsgId: uuid.New().String(), + CreateTime: time.Now().Format("2006-01-02 15:04:05"), + } + mutex.Lock() + _ = conn.WriteJSON(chatMsg) + mutex.Unlock() + api.Logger.LogInfo(fmt.Sprintf("[IMAGE] # %s\n%s", requestMsg, msg)) + return +} + func (api *Api) WsChat(c *gin.Context) { startTime := time.Now() status := StatusFail @@ -366,16 +445,29 @@ func (api *Api) WsChat(c *gin.Context) { mutex.Unlock() api.Logger.LogError(err.Error()) } else { - chatMsg := Message{ - Kind: "receive", - Msg: requestMsg, - MsgId: uuid.New().String(), - CreateTime: time.Now().Format("2006-01-02 15:04:05"), + if strings.HasPrefix(requestMsg, "/image ") { + chatMsg := Message{ + Kind: "receive", + Msg: requestMsg, + MsgId: uuid.New().String(), + CreateTime: time.Now().Format("2006-01-02 15:04:05"), + } + mutex.Lock() + _ = conn.WriteJSON(chatMsg) + mutex.Unlock() + go api.GetImageMessage(conn, cli, mutex, requestMsg) + } else { + chatMsg := Message{ + Kind: "receive", + Msg: requestMsg, + MsgId: uuid.New().String(), + CreateTime: time.Now().Format("2006-01-02 15:04:05"), + } + mutex.Lock() + _ = conn.WriteJSON(chatMsg) + mutex.Unlock() + go api.GetChatMessage(conn, cli, mutex, requestMsg) } - mutex.Lock() - _ = conn.WriteJSON(chatMsg) - mutex.Unlock() - go api.GetChatMessage(conn, cli, mutex, requestMsg) } } case websocket.CloseMessage: diff --git a/main.go b/main.go index ec097c4..687c2ef 100644 --- a/main.go +++ b/main.go @@ -56,6 +56,7 @@ func main() { } groupApi := r.Group("/api") + groupApi.Static("/assets", "assets") groupWs := groupApi.Group("/ws") groupWs.GET("chat", api.WsChat)