diff --git a/app/services/llm.py b/app/services/llm.py index ea5bc30..fca6bca 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -5,9 +5,9 @@ from typing import List from loguru import logger from openai import OpenAI from openai import AzureOpenAI -import google.generativeai as genai from app.config import config + def _generate_response(prompt: str) -> str: content = "" llm_provider = config.app.get("llm_provider", "openai") @@ -29,7 +29,7 @@ def _generate_response(prompt: str) -> str: base_url = "https://api.moonshot.cn/v1" elif llm_provider == "ollama": # api_key = config.app.get("openai_api_key") - api_key = "ollama" # any string works but you are required to have one + api_key = "ollama" # any string works but you are required to have one model_name = config.app.get("ollama_model_name") base_url = config.app.get("ollama_base_url", "") if not base_url: @@ -78,37 +78,38 @@ def _generate_response(prompt: str) -> str: return content.replace("\n", "") if llm_provider == "gemini": + import google.generativeai as genai genai.configure(api_key=api_key) generation_config = { - "temperature": 0.5, - "top_p": 1, - "top_k": 1, - "max_output_tokens": 2048, + "temperature": 0.5, + "top_p": 1, + "top_k": 1, + "max_output_tokens": 2048, } safety_settings = [ - { - "category": "HARM_CATEGORY_HARASSMENT", - "threshold": "BLOCK_ONLY_HIGH" - }, - { - "category": "HARM_CATEGORY_HATE_SPEECH", - "threshold": "BLOCK_ONLY_HIGH" - }, - { - "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", - "threshold": "BLOCK_ONLY_HIGH" - }, - { - "category": "HARM_CATEGORY_DANGEROUS_CONTENT", - "threshold": "BLOCK_ONLY_HIGH" - }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "threshold": "BLOCK_ONLY_HIGH" + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "threshold": "BLOCK_ONLY_HIGH" + }, + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "threshold": "BLOCK_ONLY_HIGH" + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "threshold": "BLOCK_ONLY_HIGH" + }, ] model = genai.GenerativeModel(model_name=model_name, - generation_config=generation_config, - safety_settings=safety_settings) + generation_config=generation_config, + safety_settings=safety_settings) convo = model.start_chat(history=[]) diff --git a/webui/Main.py b/webui/Main.py index 5bf07ef..a86458f 100644 --- a/webui/Main.py +++ b/webui/Main.py @@ -143,7 +143,7 @@ def tr(key): return loc.get("Translation", {}).get(key, key) -with st.expander(tr("Basic Settings"), expanded=True): +with st.expander(tr("Basic Settings"), expanded=False): config_panels = st.columns(3) left_config_panel = config_panels[0] middle_config_panel = config_panels[1]