mirror of
https://github.com/harry0703/MoneyPrinterTurbo.git
synced 2026-02-21 16:37:21 +08:00
added support for google gemini
This commit is contained in:
parent
478207fa7b
commit
cc1f157714
@ -5,9 +5,9 @@ from typing import List
|
||||
from loguru import logger
|
||||
from openai import OpenAI
|
||||
from openai import AzureOpenAI
|
||||
import google.generativeai as genai
|
||||
from app.config import config
|
||||
|
||||
|
||||
def _generate_response(prompt: str) -> str:
|
||||
content = ""
|
||||
llm_provider = config.app.get("llm_provider", "openai")
|
||||
@ -42,6 +42,10 @@ def _generate_response(prompt: str) -> str:
|
||||
model_name = config.app.get("azure_model_name")
|
||||
base_url = config.app.get("azure_base_url", "")
|
||||
api_version = config.app.get("azure_api_version", "2024-02-15-preview")
|
||||
elif llm_provider == "gemini":
|
||||
api_key = config.app.get("gemini_api_key")
|
||||
model_name = config.app.get("gemini_model_name")
|
||||
base_url = ""
|
||||
elif llm_provider == "qwen":
|
||||
api_key = config.app.get("qwen_api_key")
|
||||
model_name = config.app.get("qwen_model_name")
|
||||
@ -66,6 +70,44 @@ def _generate_response(prompt: str) -> str:
|
||||
content = response["output"]["text"]
|
||||
return content.replace("\n", "")
|
||||
|
||||
if llm_provider == "gemini":
|
||||
genai.configure(api_key=api_key)
|
||||
|
||||
generation_config = {
|
||||
"temperature": 0.5,
|
||||
"top_p": 1,
|
||||
"top_k": 1,
|
||||
"max_output_tokens": 2048,
|
||||
}
|
||||
|
||||
safety_settings = [
|
||||
{
|
||||
"category": "HARM_CATEGORY_HARASSMENT",
|
||||
"threshold": "BLOCK_ONLY_HIGH"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"threshold": "BLOCK_ONLY_HIGH"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"threshold": "BLOCK_ONLY_HIGH"
|
||||
},
|
||||
{
|
||||
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
"threshold": "BLOCK_ONLY_HIGH"
|
||||
},
|
||||
]
|
||||
|
||||
model = genai.GenerativeModel(model_name=model_name,
|
||||
generation_config=generation_config,
|
||||
safety_settings=safety_settings)
|
||||
|
||||
convo = model.start_chat(history=[])
|
||||
|
||||
convo.send_message(prompt)
|
||||
return convo.last.text
|
||||
|
||||
if llm_provider == "azure":
|
||||
client = AzureOpenAI(
|
||||
api_key=api_key,
|
||||
|
||||
@ -51,6 +51,10 @@
|
||||
azure_model_name="gpt-35-turbo" # replace with your model deployment name
|
||||
azure_api_version = "2024-02-15-preview"
|
||||
|
||||
########## Gemini API Key
|
||||
gemini_api_key=""
|
||||
gemini_model_name = "gemini-1.0-pro"
|
||||
|
||||
########## Qwen API Key
|
||||
# Visit https://dashscope.console.aliyun.com/apiKey to get your API key
|
||||
# Visit below links to get more details
|
||||
|
||||
@ -13,4 +13,4 @@ urllib3~=2.2.1
|
||||
pillow~=9.5.0
|
||||
pydantic~=2.6.3
|
||||
g4f~=0.2.5.4
|
||||
dashscope~=1.15.0
|
||||
dashscope~=1.15.0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user