Merge remote-tracking branch 'origin/16.0' into 16.0

This commit is contained in:
ivan deng
2023-04-23 19:35:38 +08:00
6 changed files with 80 additions and 6 deletions

View File

@@ -139,8 +139,8 @@ GPT-3 A set of models that can understand and generate natural language
res = getattr(self, 'get_%s' % self.provider)(data, author_id, answer_id, param)
# 后置勾子,返回处理后的内容
res_post, is_ai = self.get_ai_post(res, author_id, answer_id, param)
return res_post, is_ai
res_post, usage, is_ai = self.get_ai_post(res, author_id, answer_id, param)
return res_post, usage, is_ai
def get_ai_post(self, res, author_id=False, answer_id=False, param={}):
if res and author_id and isinstance(res, openai.openai_object.OpenAIObject) or isinstance(res, list) or isinstance(res, dict):
@@ -184,10 +184,10 @@ GPT-3 A set of models that can understand and generate natural language
'first_ask_time': ask_date
})
ai_use.write(vals)
return data, True
return data, usage, True
else:
# 直接返回错误语句那么就是非ai
return res, False
return res, False, False
def get_ai_system(self, content=None):
# 获取基础ai角色设定, role system

View File

@@ -63,10 +63,19 @@ class Channel(models.Model):
answer_id = user_id.partner_id
# todo: 只有个人配置的群聊才给配置
param = self.get_ai_config(ai)
res, is_ai = ai.get_ai(messages, author_id, answer_id, param)
res, usage, is_ai = ai.get_ai(messages, author_id, answer_id, param)
if res:
res = res.replace('\n', '<br/>')
channel.with_user(user_id).message_post(body=res, message_type='comment', subtype_xmlid='mail.mt_comment', parent_id=message.id)
new_msg = channel.with_user(user_id).message_post(body=res, message_type='comment', subtype_xmlid='mail.mt_comment', parent_id=message.id)
if usage:
prompt_tokens = usage['prompt_tokens']
completion_tokens = usage['completion_tokens']
total_tokens = usage['total_tokens']
new_msg.write({
'human_prompt_tokens': prompt_tokens,
'ai_completion_tokens': completion_tokens,
'cost_tokens': total_tokens,
})
def _notify_thread(self, message, msg_vals=False, **kwargs):
rdata = super(Channel, self)._notify_thread(message, msg_vals=msg_vals, **kwargs)

View File

@@ -6,8 +6,22 @@ from odoo import fields, models
class Message(models.Model):
_inherit = "mail.message"
human_prompt_tokens = fields.Integer('Human Prompt Tokens')
ai_completion_tokens = fields.Integer('AI Completion Tokens')
cost_tokens = fields.Integer('Cost Tokens')
def _message_add_reaction(self, content):
super(Message, self)._message_add_reaction(content)
if self.create_uid.gpt_id:
# 处理反馈
pass
def message_format(self, format_reply=True):
message_values = super(Message, self).message_format(format_reply=format_reply)
for message in message_values:
message_sudo = self.browse(message['id']).sudo().with_prefetch(self.ids)
message['human_prompt_tokens'] = message_sudo.human_prompt_tokens
message['ai_completion_tokens'] = message_sudo.ai_completion_tokens
message['cost_tokens'] = message_sudo.cost_tokens
return message_values