diff --git a/app_chatgpt/models/ai_robot.py b/app_chatgpt/models/ai_robot.py index d1dddad7..0ba494a9 100644 --- a/app_chatgpt/models/ai_robot.py +++ b/app_chatgpt/models/ai_robot.py @@ -133,6 +133,10 @@ GPT-3 A set of models that can understand and generate natural language # hook,都正常 return False + def get_msg_files_content(self, message): + # hook + return False + def get_ai(self, data, author_id=False, answer_id=False, param={}): # 通用方法 # author_id: 请求的 partner_id 对象 diff --git a/app_chatgpt/models/discuss_channel.py b/app_chatgpt/models/discuss_channel.py index c850ce8c..f9215e5d 100644 --- a/app_chatgpt/models/discuss_channel.py +++ b/app_chatgpt/models/discuss_channel.py @@ -167,7 +167,6 @@ class Channel(models.Model): def _notify_thread(self, message, msg_vals=False, **kwargs): rdata = super(Channel, self)._notify_thread(message, msg_vals=msg_vals, **kwargs) - # print(f'rdata:{rdata}') answer_id = self.env['res.partner'] user_id = self.env['res.users'] author_id = msg_vals.get('author_id') @@ -265,6 +264,7 @@ class Channel(models.Model): if not api_key: _logger.warning(_("ChatGPT Robot【%s】have not set open api key.")) return rdata + try: openapi_context_timeout = int(self.env['ir.config_parameter'].sudo().get_param('app_chatgpt.openapi_context_timeout')) or 60 except: @@ -294,6 +294,11 @@ class Channel(models.Model): if hasattr(channel, 'is_private') and channel.description: messages.append({"role": "system", "content": channel.description}) + if message.attachment_ids: + file_content = ai.get_msg_files_content(message) + if file_content: + messages.append({"role": "system", "content": file_content}) + try: c_history = self.get_openai_context(channel.id, author_id, answer_id, openapi_context_timeout, chat_count) if c_history: @@ -301,16 +306,16 @@ class Channel(models.Model): messages.append({"role": "user", "content": msg}) msg_len = sum(len(str(m)) for m in messages) # 接口最大接收 8430 Token + # if msg_len * 2 > ai.max_send_char: + # messages = [] + # if hasattr(channel, 'is_private') and channel.description: + # messages.append({"role": "system", "content": channel.description}) + # messages.append({"role": "user", "content": msg}) + msg_len = sum(len(str(m)) for m in messages) if msg_len * 2 > ai.max_send_char: - messages = [] - if hasattr(channel, 'is_private') and channel.description: - messages.append({"role": "system", "content": channel.description}) - messages.append({"role": "user", "content": msg}) - msg_len = sum(len(str(m)) for m in messages) - if msg_len * 2 > ai.max_send_char: - new_msg = channel.with_user(user_id).message_post(body=_('您所发送的提示词已超长。'), message_type='comment', - subtype_xmlid='mail.mt_comment', - parent_id=message.id) + new_msg = channel.with_user(user_id).message_post(body=_('您所发送的提示词已超长。'), message_type='comment', + subtype_xmlid='mail.mt_comment', + parent_id=message.id) # if msg_len * 2 >= 8000: # messages = [{"role": "user", "content": msg}]