mirror of
https://github.com/guohuadeng/app-odoo.git
synced 2025-02-23 04:11:36 +02:00
chatgpt增加azure支持
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import requests
|
||||
import openai
|
||||
from odoo import api, fields, models, _
|
||||
|
||||
|
||||
@@ -10,7 +11,7 @@ class AiRobot(models.Model):
|
||||
_order = 'sequence, name'
|
||||
|
||||
name = fields.Char(string='Name', translate=True)
|
||||
provider = fields.Selection(string="AI Provider", selection=[('openai', 'OpenAI')], required=True, default='openai')
|
||||
provider = fields.Selection(string="AI Provider", selection=[('openai', 'OpenAI'), ('azure', 'Azure')], required=True, default='openai')
|
||||
ai_model = fields.Selection(string="AI Model", selection=[
|
||||
('gpt-4', 'Chatgpt 4'),
|
||||
('gpt-3.5-turbo', 'Chatgpt 3.5 Turbo'),
|
||||
@@ -31,10 +32,19 @@ GPT-3 A set of models that can understand and generate natural language
|
||||
""")
|
||||
openapi_api_key = fields.Char(string="API Key", help="Provide the API key here")
|
||||
temperature = fields.Float(string='Temperature', default=0.9)
|
||||
|
||||
|
||||
max_length = fields.Integer('Max Length', default=100)
|
||||
sequence = fields.Integer('Sequence', help="Determine the display order", default=10)
|
||||
|
||||
def action_disconnect(self):
|
||||
requests.delete('https://chatgpt.com/v1/disconnect')
|
||||
|
||||
def get_openai(self, data):
|
||||
openai.api_type = self.provider
|
||||
openai.api_base = "https://odooapp.openai.azure.com/"
|
||||
openai.api_version = "2022-12-01"
|
||||
openai.api_key = self.openapi_api_key
|
||||
response = openai.Completion.create(engine='odooapp', prompt=data, temperature=self.temperature, max_tokens=self.max_length, top_p=0.5, frequency_penalty=0,
|
||||
presence_penalty=0, stop=["Human:", "AI:"])
|
||||
if 'choices' in response:
|
||||
res = response['choices'][0]['text'].replace(' .', '.').strip()
|
||||
return res
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
|
||||
import openai
|
||||
import requests,json
|
||||
import datetime
|
||||
@@ -15,7 +14,7 @@ class Channel(models.Model):
|
||||
_inherit = 'mail.channel'
|
||||
|
||||
@api.model
|
||||
def get_openai(self, api_key, ai_model, data, user="Odoo"):
|
||||
def get_openai(self, gpt_id, provider, api_key, ai_model, data, user="Odoo"):
|
||||
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
|
||||
R_TIMEOUT = 5
|
||||
|
||||
@@ -41,12 +40,16 @@ class Channel(models.Model):
|
||||
"user": user,
|
||||
"stop": ["Human:", "AI:"]
|
||||
}
|
||||
response = requests.post("https://api.openai.com/v1/chat/completions", data=json.dumps(pdata), headers=headers, timeout=R_TIMEOUT)
|
||||
res = response.json()
|
||||
if 'choices' in res:
|
||||
# for rec in res:
|
||||
# res = rec['message']['content']
|
||||
res = '\n'.join([x['message']['content'] for x in res['choices']])
|
||||
if provider == 'openai':
|
||||
response = requests.post("https://api.openai.com/v1/chat/completions", data=json.dumps(pdata), headers=headers, timeout=R_TIMEOUT)
|
||||
res = response.json()
|
||||
if 'choices' in res:
|
||||
# for rec in res:
|
||||
# res = rec['message']['content']
|
||||
res = '\n'.join([x['message']['content'] for x in res['choices']])
|
||||
return res
|
||||
elif provider == 'azure':
|
||||
res = gpt_id.get_openai(data)
|
||||
return res
|
||||
else:
|
||||
pdata = {
|
||||
@@ -60,10 +63,14 @@ class Channel(models.Model):
|
||||
"user": user,
|
||||
"stop": ["Human:", "AI:"]
|
||||
}
|
||||
response = requests.post("https://api.openai.com/v1/completions", data=json.dumps(pdata), headers=headers, timeout=R_TIMEOUT)
|
||||
res = response.json()
|
||||
if 'choices' in res:
|
||||
res = '\n'.join([x['text'] for x in res['choices']])
|
||||
if provider == 'openai':
|
||||
response = requests.post("https://api.openai.com/v1/completions", data=json.dumps(pdata), headers=headers, timeout=R_TIMEOUT)
|
||||
res = response.json()
|
||||
if 'choices' in res:
|
||||
res = '\n'.join([x['text'] for x in res['choices']])
|
||||
return res
|
||||
elif provider == 'azure':
|
||||
res = gpt_id.get_openai(data)
|
||||
return res
|
||||
# 获取模型信息
|
||||
# list_model = requests.get("https://api.openai.com/v1/models", headers=headers)
|
||||
@@ -178,6 +185,7 @@ class Channel(models.Model):
|
||||
# print(msg_vals.get('record_name', ''))
|
||||
# print('self.channel_type :',self.channel_type)
|
||||
if gpt_id:
|
||||
provider = gpt_id.provider
|
||||
ai_model = gpt_id.ai_model or 'text-davinci-003'
|
||||
# print('chatgpt_name:', chatgpt_name)
|
||||
# if author_id != to_partner_id.id and (chatgpt_name in msg_vals.get('record_name', '') or 'ChatGPT' in msg_vals.get('record_name', '') ) and self.channel_type == 'chat':
|
||||
@@ -189,7 +197,7 @@ class Channel(models.Model):
|
||||
prompt = self.get_openai_context(channel.id, to_partner_id.id, prompt, openapi_context_timeout)
|
||||
print(prompt)
|
||||
# res = self.get_chatgpt_answer(prompt,partner_name)
|
||||
res = self.get_openai(api_key, ai_model, prompt, partner_name)
|
||||
res = self.get_openai(gpt_id, provider, api_key, ai_model, prompt, partner_name)
|
||||
res = res.replace('\n', '<br/>')
|
||||
# print('res:',res)
|
||||
# print('channel:',channel)
|
||||
@@ -211,7 +219,7 @@ class Channel(models.Model):
|
||||
prompt = self.get_openai_context(chatgpt_channel_id.id, to_partner_id.id, prompt, openapi_context_timeout)
|
||||
# print(prompt)
|
||||
# res = self.get_chatgpt_answer(prompt, partner_name)
|
||||
res = self.get_openai(api_key, ai_model, prompt, partner_name)
|
||||
res = self.get_openai(gpt_id, provider, api_key, ai_model, prompt, partner_name)
|
||||
res = res.replace('\n', '<br/>')
|
||||
chatgpt_channel_id.with_user(user_id).message_post(body=res, message_type='comment', subtype_xmlid='mail.mt_comment',parent_id=message.id)
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user