chatgpt增加azure支持

This commit is contained in:
Chill
2023-03-16 16:20:59 +08:00
parent e0d854dcf9
commit 3ee79fd7fb
7 changed files with 70 additions and 21 deletions

View File

@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-
import openai
import requests,json
import datetime
@@ -15,7 +14,7 @@ class Channel(models.Model):
_inherit = 'mail.channel'
@api.model
def get_openai(self, api_key, ai_model, data, user="Odoo"):
def get_openai(self, gpt_id, provider, api_key, ai_model, data, user="Odoo"):
headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}
R_TIMEOUT = 5
@@ -41,12 +40,16 @@ class Channel(models.Model):
"user": user,
"stop": ["Human:", "AI:"]
}
response = requests.post("https://api.openai.com/v1/chat/completions", data=json.dumps(pdata), headers=headers, timeout=R_TIMEOUT)
res = response.json()
if 'choices' in res:
# for rec in res:
# res = rec['message']['content']
res = '\n'.join([x['message']['content'] for x in res['choices']])
if provider == 'openai':
response = requests.post("https://api.openai.com/v1/chat/completions", data=json.dumps(pdata), headers=headers, timeout=R_TIMEOUT)
res = response.json()
if 'choices' in res:
# for rec in res:
# res = rec['message']['content']
res = '\n'.join([x['message']['content'] for x in res['choices']])
return res
elif provider == 'azure':
res = gpt_id.get_openai(data)
return res
else:
pdata = {
@@ -60,10 +63,14 @@ class Channel(models.Model):
"user": user,
"stop": ["Human:", "AI:"]
}
response = requests.post("https://api.openai.com/v1/completions", data=json.dumps(pdata), headers=headers, timeout=R_TIMEOUT)
res = response.json()
if 'choices' in res:
res = '\n'.join([x['text'] for x in res['choices']])
if provider == 'openai':
response = requests.post("https://api.openai.com/v1/completions", data=json.dumps(pdata), headers=headers, timeout=R_TIMEOUT)
res = response.json()
if 'choices' in res:
res = '\n'.join([x['text'] for x in res['choices']])
return res
elif provider == 'azure':
res = gpt_id.get_openai(data)
return res
# 获取模型信息
# list_model = requests.get("https://api.openai.com/v1/models", headers=headers)
@@ -178,6 +185,7 @@ class Channel(models.Model):
# print(msg_vals.get('record_name', ''))
# print('self.channel_type :',self.channel_type)
if gpt_id:
provider = gpt_id.provider
ai_model = gpt_id.ai_model or 'text-davinci-003'
# print('chatgpt_name:', chatgpt_name)
# if author_id != to_partner_id.id and (chatgpt_name in msg_vals.get('record_name', '') or 'ChatGPT' in msg_vals.get('record_name', '') ) and self.channel_type == 'chat':
@@ -189,7 +197,7 @@ class Channel(models.Model):
prompt = self.get_openai_context(channel.id, to_partner_id.id, prompt, openapi_context_timeout)
print(prompt)
# res = self.get_chatgpt_answer(prompt,partner_name)
res = self.get_openai(api_key, ai_model, prompt, partner_name)
res = self.get_openai(gpt_id, provider, api_key, ai_model, prompt, partner_name)
res = res.replace('\n', '<br/>')
# print('res:',res)
# print('channel:',channel)
@@ -211,7 +219,7 @@ class Channel(models.Model):
prompt = self.get_openai_context(chatgpt_channel_id.id, to_partner_id.id, prompt, openapi_context_timeout)
# print(prompt)
# res = self.get_chatgpt_answer(prompt, partner_name)
res = self.get_openai(api_key, ai_model, prompt, partner_name)
res = self.get_openai(gpt_id, provider, api_key, ai_model, prompt, partner_name)
res = res.replace('\n', '<br/>')
chatgpt_channel_id.with_user(user_id).message_post(body=res, message_type='comment', subtype_xmlid='mail.mt_comment',parent_id=message.id)
except Exception as e: