diff --git a/app/services/gpt.py b/app/services/llm.py similarity index 77% rename from app/services/gpt.py rename to app/services/llm.py index 9bdde50..960dc9a 100644 --- a/app/services/gpt.py +++ b/app/services/llm.py @@ -1,32 +1,45 @@ import logging import re import json -import openai from typing import List from loguru import logger +from openai import OpenAI from app.config import config -openai_api_key = config.app.get("openai_api_key") -if not openai_api_key: - raise ValueError("openai_api_key is not set, please set it in the config.toml file.") - -openai_model_name = config.app.get("openai_model_name") -if not openai_model_name: - raise ValueError("openai_model_name is not set, please set it in the config.toml file.") - -openai_base_url = config.app.get("openai_base_url") - -openai.api_key = openai_api_key -openai_model_name = openai_model_name -if openai_base_url: - openai.base_url = openai_base_url - def _generate_response(prompt: str) -> str: - model_name = openai_model_name + llm_provider = config.app.get("llm_provider") + if llm_provider == "moonshot": + api_key = config.app.get("moonshot_api_key") + model_name = config.app.get("moonshot_model_name") + base_url = "https://api.moonshot.cn/v1" + elif llm_provider == "openai": + api_key = config.app.get("openai_api_key") + model_name = config.app.get("openai_model_name") + base_url = config.app.get("openai_base_url", "") + if not base_url: + base_url = "https://api.openai.com/v1" + elif llm_provider == "oneapi": + api_key = config.app.get("oneapi_api_key") + model_name = config.app.get("oneapi_model_name") + base_url = config.app.get("oneapi_base_url", "") + else: + raise ValueError("llm_provider is not set, please set it in the config.toml file.") - response = openai.chat.completions.create( + if not api_key: + raise ValueError(f"{llm_provider}: api_key is not set, please set it in the config.toml file.") + if not model_name: + raise ValueError(f"{llm_provider}: model_name is not set, please set it in the config.toml file.") + if not base_url: + raise ValueError(f"{llm_provider}: base_url is not set, please set it in the config.toml file.") + + client = OpenAI( + api_key=api_key, + base_url=base_url, + ) + + response = client.chat.completions.create( model=model_name, messages=[{"role": "user", "content": prompt}], ).choices[0].message.content diff --git a/app/services/task.py b/app/services/task.py index cbe2804..079556b 100644 --- a/app/services/task.py +++ b/app/services/task.py @@ -4,7 +4,7 @@ from loguru import logger from app.config import config from app.models.schema import VideoParams, VoiceNames -from app.services import gpt, material, voice, video, subtitle +from app.services import llm, material, voice, video, subtitle from app.utils import utils @@ -42,10 +42,10 @@ def start(task_id, params: VideoParams): n_threads = params.n_threads logger.info("\n\n## generating video script") - script = gpt.generate_script(video_subject=video_subject, language=language, paragraph_number=paragraph_number) + script = llm.generate_script(video_subject=video_subject, language=language, paragraph_number=paragraph_number) logger.info("\n\n## generating video terms") - search_terms = gpt.generate_terms(video_subject=video_subject, video_script=script, amount=5) + search_terms = llm.generate_terms(video_subject=video_subject, video_script=script, amount=5) script_file = path.join(utils.task_dir(task_id), f"script.json") script_data = { diff --git a/config.example.toml b/config.example.toml index c6373d5..7fc8ae5 100644 --- a/config.example.toml +++ b/config.example.toml @@ -5,12 +5,26 @@ # For example: pexels_api_keys = ["123456789","abcdefghi"] pexels_api_keys = [] + llm_provider="openai" # "openai" or "moonshot" or "oneapi" + # OpenAI API Key # Visit https://openai.com/api/ for details on obtaining an API key. openai_api_key = "" - openai_base_url="" + openai_base_url = "" openai_model_name = "gpt-4-turbo-preview" + # Moonshot API Key + # Visit https://platform.moonshot.cn/console/api-keys to get your API key. + moonshot_api_key="" + moonshot_base_url = "https://api.moonshot.cn/v1" + moonshot_model_name = "moonshot-v1-8k" + + # OneAPI API Key + # Visit https://github.com/songquanpeng/one-api to get your API key + oneapi_api_key="" + oneapi_base_url="" + oneapi_model_name="" + # Subtitle Provider, "edge" or "whisper" # If empty, the subtitle will not be generated subtitle_provider = "edge"