added support for google gemini

This commit is contained in:
PD 2024-03-31 10:44:52 +05:30 committed by GitHub
parent 478207fa7b
commit cc1f157714
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 48 additions and 2 deletions

View File

@ -5,9 +5,9 @@ from typing import List
from loguru import logger
from openai import OpenAI
from openai import AzureOpenAI
import google.generativeai as genai
from app.config import config
def _generate_response(prompt: str) -> str:
content = ""
llm_provider = config.app.get("llm_provider", "openai")
@ -42,6 +42,10 @@ def _generate_response(prompt: str) -> str:
model_name = config.app.get("azure_model_name")
base_url = config.app.get("azure_base_url", "")
api_version = config.app.get("azure_api_version", "2024-02-15-preview")
elif llm_provider == "gemini":
api_key = config.app.get("gemini_api_key")
model_name = config.app.get("gemini_model_name")
base_url = ""
elif llm_provider == "qwen":
api_key = config.app.get("qwen_api_key")
model_name = config.app.get("qwen_model_name")
@ -66,6 +70,44 @@ def _generate_response(prompt: str) -> str:
content = response["output"]["text"]
return content.replace("\n", "")
if llm_provider == "gemini":
genai.configure(api_key=api_key)
generation_config = {
"temperature": 0.5,
"top_p": 1,
"top_k": 1,
"max_output_tokens": 2048,
}
safety_settings = [
{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_ONLY_HIGH"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_ONLY_HIGH"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_ONLY_HIGH"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_ONLY_HIGH"
},
]
model = genai.GenerativeModel(model_name=model_name,
generation_config=generation_config,
safety_settings=safety_settings)
convo = model.start_chat(history=[])
convo.send_message(prompt)
return convo.last.text
if llm_provider == "azure":
client = AzureOpenAI(
api_key=api_key,

View File

@ -51,6 +51,10 @@
azure_model_name="gpt-35-turbo" # replace with your model deployment name
azure_api_version = "2024-02-15-preview"
########## Gemini API Key
gemini_api_key=""
gemini_model_name = "gemini-1.0-pro"
########## Qwen API Key
# Visit https://dashscope.console.aliyun.com/apiKey to get your API key
# Visit below links to get more details

View File

@ -13,4 +13,4 @@ urllib3~=2.2.1
pillow~=9.5.0
pydantic~=2.6.3
g4f~=0.2.5.4
dashscope~=1.15.0
dashscope~=1.15.0