diff --git a/app/services/llm.py b/app/services/llm.py
index d361add..3c48c45 100644
--- a/app/services/llm.py
+++ b/app/services/llm.py
@@ -5,9 +5,10 @@ from typing import List
 from loguru import logger
 from openai import OpenAI
 from openai import AzureOpenAI
+import google.generativeai as genai
 
 from app.config import config
 
 def _generate_response(prompt: str) -> str:
     content = ""
     llm_provider = config.app.get("llm_provider", "openai")
@@ -42,6 +43,10 @@ def _generate_response(prompt: str) -> str:
         model_name = config.app.get("azure_model_name")
         base_url = config.app.get("azure_base_url", "")
         api_version = config.app.get("azure_api_version", "2024-02-15-preview")
+    elif llm_provider == "gemini":
+        api_key = config.app.get("gemini_api_key")
+        model_name = config.app.get("gemini_model_name")
+        base_url = ""
     elif llm_provider == "qwen":
         api_key = config.app.get("qwen_api_key")
         model_name = config.app.get("qwen_model_name")
@@ -66,6 +71,44 @@ def _generate_response(prompt: str) -> str:
         content = response["output"]["text"]
         return content.replace("\n", "")
 
+    if llm_provider == "gemini":
+        genai.configure(api_key=api_key)
+
+        generation_config = {
+            "temperature": 0.5,
+            "top_p": 1,
+            "top_k": 1,
+            "max_output_tokens": 2048,
+        }
+
+        safety_settings = [
+            {
+                "category": "HARM_CATEGORY_HARASSMENT",
+                "threshold": "BLOCK_ONLY_HIGH"
+            },
+            {
+                "category": "HARM_CATEGORY_HATE_SPEECH",
+                "threshold": "BLOCK_ONLY_HIGH"
+            },
+            {
+                "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                "threshold": "BLOCK_ONLY_HIGH"
+            },
+            {
+                "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                "threshold": "BLOCK_ONLY_HIGH"
+            },
+        ]
+
+        model = genai.GenerativeModel(model_name=model_name,
+                                      generation_config=generation_config,
+                                      safety_settings=safety_settings)
+
+        convo = model.start_chat(history=[])
+
+        convo.send_message(prompt)
+        return convo.last.text
+
     if llm_provider == "azure":
         client = AzureOpenAI(
             api_key=api_key,
diff --git a/config.example.toml b/config.example.toml
index 14d257a..9a87a8a 100644
--- a/config.example.toml
+++ b/config.example.toml
@@ -51,6 +51,10 @@
 azure_model_name="gpt-35-turbo" # replace with your model deployment name
 azure_api_version = "2024-02-15-preview"
 
+########## Gemini API Key
+gemini_api_key=""
+gemini_model_name = "gemini-1.0-pro"
+
 ########## Qwen API Key
 # Visit https://dashscope.console.aliyun.com/apiKey to get your API key
 # Visit below links to get more details
diff --git a/requirements.txt b/requirements.txt
index 651f46a..2868f3b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,4 +13,5 @@ urllib3~=2.2.1
 pillow~=9.5.0
 pydantic~=2.6.3
 g4f~=0.2.5.4
-dashscope~=1.15.0
\ No newline at end of file
+dashscope~=1.15.0
+google-generativeai~=0.4.1