From cc1f157714f526902b65b7a4926dc6a14238f2fc Mon Sep 17 00:00:00 2001 From: PD <56485898+pratham-darooka@users.noreply.github.com> Date: Sun, 31 Mar 2024 10:44:52 +0530 Subject: [PATCH 1/3] added support for google gemini --- app/services/llm.py | 44 +++++++++++++++++++++++++++++++++++++++++++- config.example.toml | 4 ++++ requirements.txt | 2 +- 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/app/services/llm.py b/app/services/llm.py index d361add..3c48c45 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -5,9 +5,9 @@ from typing import List from loguru import logger from openai import OpenAI from openai import AzureOpenAI +import google.generativeai as genai from app.config import config - def _generate_response(prompt: str) -> str: content = "" llm_provider = config.app.get("llm_provider", "openai") @@ -42,6 +42,10 @@ def _generate_response(prompt: str) -> str: model_name = config.app.get("azure_model_name") base_url = config.app.get("azure_base_url", "") api_version = config.app.get("azure_api_version", "2024-02-15-preview") + elif llm_provider == "gemini": + api_key = config.app.get("gemini_api_key") + model_name = config.app.get("gemini_model_name") + base_url = "" elif llm_provider == "qwen": api_key = config.app.get("qwen_api_key") model_name = config.app.get("qwen_model_name") @@ -66,6 +70,44 @@ def _generate_response(prompt: str) -> str: content = response["output"]["text"] return content.replace("\n", "") + if llm_provider == "gemini": + genai.configure(api_key=api_key) + + generation_config = { + "temperature": 0.5, + "top_p": 1, + "top_k": 1, + "max_output_tokens": 2048, + } + + safety_settings = [ + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_ONLY_HIGH" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_ONLY_HIGH" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_ONLY_HIGH" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_ONLY_HIGH" + }, + ] + + model = genai.GenerativeModel(model_name=model_name, + generation_config=generation_config, + safety_settings=safety_settings) + + convo = model.start_chat(history=[]) + + convo.send_message(prompt) + return convo.last.text + if llm_provider == "azure": client = AzureOpenAI( api_key=api_key, diff --git a/config.example.toml b/config.example.toml index 14d257a..9a87a8a 100644 --- a/config.example.toml +++ b/config.example.toml @@ -51,6 +51,10 @@ azure_model_name="gpt-35-turbo" # replace with your model deployment name azure_api_version = "2024-02-15-preview" + ########## Gemini API Key + gemini_api_key="" + gemini_model_name = "gemini-1.0-pro" + ########## Qwen API Key # Visit https://dashscope.console.aliyun.com/apiKey to get your API key # Visit below links to get more details diff --git a/requirements.txt b/requirements.txt index 651f46a..2868f3b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,4 @@ urllib3~=2.2.1 pillow~=9.5.0 pydantic~=2.6.3 g4f~=0.2.5.4 -dashscope~=1.15.0 \ No newline at end of file +dashscope~=1.15.0 From 1e2d63a1c0e34d9cb5f45a4bd9bde0203e88bda1 Mon Sep 17 00:00:00 2001 From: PD <56485898+pratham-darooka@users.noreply.github.com> Date: Sun, 31 Mar 2024 11:01:59 +0530 Subject: [PATCH 2/3] updated requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 2868f3b..d67aa75 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,3 +14,4 @@ pillow~=9.5.0 pydantic~=2.6.3 g4f~=0.2.5.4 dashscope~=1.15.0 +google.generativeai~=0.4.1 \ No newline at end of file From dacd35f5223cf69edef524dafaa787f5377c8882 Mon Sep 17 00:00:00 2001 From: PD <56485898+pratham-darooka@users.noreply.github.com> Date: Sun, 31 Mar 2024 11:07:19 +0530 Subject: [PATCH 3/3] fixed bug --- app/services/llm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/services/llm.py b/app/services/llm.py index 3c48c45..84ada77 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -45,7 +45,7 @@ def _generate_response(prompt: str) -> str: elif llm_provider == "gemini": api_key = config.app.get("gemini_api_key") model_name = config.app.get("gemini_model_name") - base_url = "" + base_url = "***" elif llm_provider == "qwen": api_key = config.app.get("qwen_api_key") model_name = config.app.get("qwen_model_name")