support generate image by prompt

This commit is contained in:
cookeem
2023-03-22 23:04:05 +08:00
parent 38f7a73288
commit 3315e5940f
3 changed files with 122 additions and 11 deletions

View File

@@ -1,8 +1,10 @@
package chat
import (
"fmt"
"github.com/sashabaranov/go-openai"
log "github.com/sirupsen/logrus"
"math/rand"
"os"
"time"
)
@@ -39,6 +41,22 @@ func (logger Logger) LogPanic(args ...interface{}) {
log.Panic(args...)
}
func RandomString(n int) string {
var letter []rune
lowerChars := "abcdefghijklmnopqrstuvwxyz"
numberChars := "0123456789"
chars := fmt.Sprintf("%s%s", lowerChars, numberChars)
letter = []rune(chars)
var str string
b := make([]rune, n)
seededRand := rand.New(rand.NewSource(time.Now().UnixNano()))
for i := range b {
b[i] = letter[seededRand.Intn(len(letter))]
}
str = string(b)
return str
}
const (
StatusFail string = "FAIL"

View File

@@ -2,10 +2,12 @@ package chat
import (
"context"
"encoding/base64"
"errors"
"fmt"
"io"
"net/http"
"os"
"strings"
"sync"
"time"
@@ -100,7 +102,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex *
stream, err := cli.CreateChatCompletionStream(ctx, req)
if err != nil {
err = fmt.Errorf("[ERROR] create chatGPT stream model=%s error: %s", api.Config.Model, err.Error())
err = fmt.Errorf("[ERROR] create ChatGPT stream model=%s error: %s", api.Config.Model, err.Error())
chatMsg := Message{
Kind: "error",
Msg: err.Error(),
@@ -185,7 +187,7 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex *
stream, err := cli.CreateCompletionStream(ctx, req)
if err != nil {
err = fmt.Errorf("[ERROR] create chatGPT stream model=%s error: %s", api.Config.Model, err.Error())
err = fmt.Errorf("[ERROR] create ChatGPT stream model=%s error: %s", api.Config.Model, err.Error())
chatMsg := Message{
Kind: "error",
Msg: err.Error(),
@@ -262,6 +264,83 @@ func (api *Api) GetChatMessage(conn *websocket.Conn, cli *openai.Client, mutex *
}
}
func (api *Api) GetImageMessage(conn *websocket.Conn, cli *openai.Client, mutex *sync.Mutex, requestMsg string) {
var err error
ctx := context.Background()
prompt := strings.TrimPrefix(requestMsg, "/image ")
req := openai.ImageRequest{
Prompt: prompt,
Size: openai.CreateImageSize256x256,
ResponseFormat: openai.CreateImageResponseFormatB64JSON,
N: 1,
}
sendError := func(err error) {
err = fmt.Errorf("[ERROR] generate image error: %s", err.Error())
chatMsg := Message{
Kind: "error",
Msg: err.Error(),
MsgId: uuid.New().String(),
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
}
mutex.Lock()
_ = conn.WriteJSON(chatMsg)
mutex.Unlock()
api.Logger.LogError(err.Error())
}
resp, err := cli.CreateImage(ctx, req)
if err != nil {
err = fmt.Errorf("[ERROR] generate image error: %s", err.Error())
sendError(err)
return
}
if len(resp.Data) == 0 {
err = fmt.Errorf("[ERROR] generate image error: result is empty")
sendError(err)
return
}
imgBytes, err := base64.StdEncoding.DecodeString(resp.Data[0].B64JSON)
if err != nil {
err = fmt.Errorf("[ERROR] image base64 decode error: %s", err.Error())
sendError(err)
return
}
date := time.Now().Format("2006-01-02")
imageDir := fmt.Sprintf("assets/images/%s", date)
err = os.MkdirAll(imageDir, 0700)
if err != nil {
err = fmt.Errorf("[ERROR] create image directory error: %s", err.Error())
sendError(err)
return
}
imageFileName := fmt.Sprintf("%s.png", RandomString(16))
err = os.WriteFile(fmt.Sprintf("%s/%s", imageDir, imageFileName), imgBytes, 0600)
if err != nil {
err = fmt.Errorf("[ERROR] write png image error: %s", err.Error())
sendError(err)
return
}
msg := fmt.Sprintf("api/%s/%s", imageDir, imageFileName)
chatMsg := Message{
Kind: "image",
Msg: msg,
MsgId: uuid.New().String(),
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
}
mutex.Lock()
_ = conn.WriteJSON(chatMsg)
mutex.Unlock()
api.Logger.LogInfo(fmt.Sprintf("[IMAGE] # %s\n%s", requestMsg, msg))
return
}
func (api *Api) WsChat(c *gin.Context) {
startTime := time.Now()
status := StatusFail
@@ -366,16 +445,29 @@ func (api *Api) WsChat(c *gin.Context) {
mutex.Unlock()
api.Logger.LogError(err.Error())
} else {
chatMsg := Message{
Kind: "receive",
Msg: requestMsg,
MsgId: uuid.New().String(),
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
if strings.HasPrefix(requestMsg, "/image ") {
chatMsg := Message{
Kind: "receive",
Msg: requestMsg,
MsgId: uuid.New().String(),
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
}
mutex.Lock()
_ = conn.WriteJSON(chatMsg)
mutex.Unlock()
go api.GetImageMessage(conn, cli, mutex, requestMsg)
} else {
chatMsg := Message{
Kind: "receive",
Msg: requestMsg,
MsgId: uuid.New().String(),
CreateTime: time.Now().Format("2006-01-02 15:04:05"),
}
mutex.Lock()
_ = conn.WriteJSON(chatMsg)
mutex.Unlock()
go api.GetChatMessage(conn, cli, mutex, requestMsg)
}
mutex.Lock()
_ = conn.WriteJSON(chatMsg)
mutex.Unlock()
go api.GetChatMessage(conn, cli, mutex, requestMsg)
}
}
case websocket.CloseMessage:

View File

@@ -56,6 +56,7 @@ func main() {
}
groupApi := r.Group("/api")
groupApi.Static("/assets", "assets")
groupWs := groupApi.Group("/ws")
groupWs.GET("chat", api.WsChat)