From 2d8cd23fe78ad960c83fca7dfdce958b7051be53 Mon Sep 17 00:00:00 2001
From: yyhhyyyyyy <95077259+yyhhyyyyyy@users.noreply.github.com>
Date: Thu, 12 Dec 2024 14:29:14 +0800
Subject: [PATCH] =?UTF-8?q?=F0=9F=90=9B=20fix:=20fix=20the=20LLM=20logic?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 app/services/llm.py  | 457 ++++++++++++++++++++++---------------------
 app/services/task.py |   2 +-
 webui/Main.py        |  60 +++---
 3 files changed, 270 insertions(+), 249 deletions(-)
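The substance of the patch is in _generate_response(): the whole body is wrapped
in a broad try/except, and any exception is converted into a string beginning
with "Error: " that every caller then string-matches. A minimal sketch of that
contract (the function name and the forced failure below are stand-ins for
illustration, not code from the patch):

    def _generate_response_sketch(prompt: str) -> str:
        try:
            raise RuntimeError("api_key is not set")  # stand-in for any provider failure
        except Exception as e:
            return f"Error: {str(e)}"

    result = _generate_response_sketch("some prompt")
    if "Error: " in result:
        print(f"generation failed: {result}")

One consequence of this design is that errors travel through ordinary return
values, so each call site in task.py and webui/Main.py below has to remember to
check for the sentinel.
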
diff --git a/app/services/llm.py b/app/services/llm.py
index b02a68d..5b3baa7 100644
--- a/app/services/llm.py
+++ b/app/services/llm.py
@@ -3,6 +3,7 @@ import logging
 import re
 from typing import List
 
+import g4f
 from loguru import logger
 from openai import AzureOpenAI, OpenAI
 from openai.types.chat import ChatCompletion
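Note that the hunk above hoists `import g4f` to module scope while the hunk
below deletes the lazy `import g4f` inside the provider branch; after this
change the package must be importable even when a different provider is
configured. A guarded import, shown here only as a hedged alternative sketch
and not what the patch does, would keep the dependency optional:

    try:
        import g4f  # assumed to be an optional dependency
    except ImportError:
        g4f = None  # the g4f branch would then need to check for None before use
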
@@ -13,243 +14,244 @@ _max_retries = 5
 
 
 def _generate_response(prompt: str) -> str:
-    content = ""
-    llm_provider = config.app.get("llm_provider", "openai")
-    logger.info(f"llm provider: {llm_provider}")
-    if llm_provider == "g4f":
-        model_name = config.app.get("g4f_model_name", "")
-        if not model_name:
-            model_name = "gpt-3.5-turbo-16k-0613"
-        import g4f
-
-        content = g4f.ChatCompletion.create(
-            model=model_name,
-            messages=[{"role": "user", "content": prompt}],
-        )
-    else:
-        api_version = ""  # for azure
-        if llm_provider == "moonshot":
-            api_key = config.app.get("moonshot_api_key")
-            model_name = config.app.get("moonshot_model_name")
-            base_url = "https://api.moonshot.cn/v1"
-        elif llm_provider == "ollama":
-            # api_key = config.app.get("openai_api_key")
-            api_key = "ollama"  # any string works but you are required to have one
-            model_name = config.app.get("ollama_model_name")
-            base_url = config.app.get("ollama_base_url", "")
-            if not base_url:
-                base_url = "http://localhost:11434/v1"
-        elif llm_provider == "openai":
-            api_key = config.app.get("openai_api_key")
-            model_name = config.app.get("openai_model_name")
-            base_url = config.app.get("openai_base_url", "")
-            if not base_url:
-                base_url = "https://api.openai.com/v1"
-        elif llm_provider == "oneapi":
-            api_key = config.app.get("oneapi_api_key")
-            model_name = config.app.get("oneapi_model_name")
-            base_url = config.app.get("oneapi_base_url", "")
-        elif llm_provider == "azure":
-            api_key = config.app.get("azure_api_key")
-            model_name = config.app.get("azure_model_name")
-            base_url = config.app.get("azure_base_url", "")
-            api_version = config.app.get("azure_api_version", "2024-02-15-preview")
-        elif llm_provider == "gemini":
-            api_key = config.app.get("gemini_api_key")
-            model_name = config.app.get("gemini_model_name")
-            base_url = "***"
-        elif llm_provider == "qwen":
-            api_key = config.app.get("qwen_api_key")
-            model_name = config.app.get("qwen_model_name")
-            base_url = "***"
-        elif llm_provider == "cloudflare":
-            api_key = config.app.get("cloudflare_api_key")
-            model_name = config.app.get("cloudflare_model_name")
-            account_id = config.app.get("cloudflare_account_id")
-            base_url = "***"
-        elif llm_provider == "deepseek":
-            api_key = config.app.get("deepseek_api_key")
-            model_name = config.app.get("deepseek_model_name")
-            base_url = config.app.get("deepseek_base_url")
-            if not base_url:
-                base_url = "https://api.deepseek.com"
-        elif llm_provider == "ernie":
-            api_key = config.app.get("ernie_api_key")
-            secret_key = config.app.get("ernie_secret_key")
-            base_url = config.app.get("ernie_base_url")
-            model_name = "***"
-            if not secret_key:
-                raise ValueError(
-                    f"{llm_provider}: secret_key is not set, please set it in the config.toml file."
-                )
-        else:
-            raise ValueError(
-                "llm_provider is not set, please set it in the config.toml file."
-            )
-
-        if not api_key:
-            raise ValueError(
-                f"{llm_provider}: api_key is not set, please set it in the config.toml file."
-            )
-        if not model_name:
-            raise ValueError(
-                f"{llm_provider}: model_name is not set, please set it in the config.toml file."
-            )
-        if not base_url:
-            raise ValueError(
-                f"{llm_provider}: base_url is not set, please set it in the config.toml file."
-            )
-
-        if llm_provider == "qwen":
-            import dashscope
-            from dashscope.api_entities.dashscope_response import GenerationResponse
-
-            dashscope.api_key = api_key
-            response = dashscope.Generation.call(
-                model=model_name, messages=[{"role": "user", "content": prompt}]
-            )
-            if response:
-                if isinstance(response, GenerationResponse):
-                    status_code = response.status_code
-                    if status_code != 200:
-                        raise Exception(
-                            f'[{llm_provider}] returned an error response: "{response}"'
-                        )
-
-                    content = response["output"]["text"]
-                    return content.replace("\n", "")
-                else:
-                    raise Exception(
-                        f'[{llm_provider}] returned an invalid response: "{response}"'
-                    )
-            else:
-                raise Exception(f"[{llm_provider}] returned an empty response")
-
-        if llm_provider == "gemini":
-            import google.generativeai as genai
-
-            genai.configure(api_key=api_key, transport="rest")
-
-            generation_config = {
-                "temperature": 0.5,
-                "top_p": 1,
-                "top_k": 1,
-                "max_output_tokens": 2048,
-            }
-
-            safety_settings = [
-                {
-                    "category": "HARM_CATEGORY_HARASSMENT",
-                    "threshold": "BLOCK_ONLY_HIGH",
-                },
-                {
-                    "category": "HARM_CATEGORY_HATE_SPEECH",
-                    "threshold": "BLOCK_ONLY_HIGH",
-                },
-                {
-                    "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
-                    "threshold": "BLOCK_ONLY_HIGH",
-                },
-                {
-                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
-                    "threshold": "BLOCK_ONLY_HIGH",
-                },
-            ]
-
-            model = genai.GenerativeModel(
-                model_name=model_name,
-                generation_config=generation_config,
-                safety_settings=safety_settings,
-            )
-
-            try:
-                response = model.generate_content(prompt)
-                candidates = response.candidates
-                generated_text = candidates[0].content.parts[0].text
-            except (AttributeError, IndexError) as e:
-                print("Gemini Error:", e)
-
-            return generated_text
-
-        if llm_provider == "cloudflare":
-            import requests
-
-            response = requests.post(
-                f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}",
-                headers={"Authorization": f"Bearer {api_key}"},
-                json={
-                    "messages": [
-                        {"role": "system", "content": "You are a friendly assistant"},
-                        {"role": "user", "content": prompt},
-                    ]
-                },
-            )
-            result = response.json()
-            logger.info(result)
-            return result["result"]["response"]
-
-        if llm_provider == "ernie":
-            import requests
-
-            params = {
-                "grant_type": "client_credentials",
-                "client_id": api_key,
-                "client_secret": secret_key,
-            }
-            access_token = (
-                requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params)
-                .json()
-                .get("access_token")
-            )
-            url = f"{base_url}?access_token={access_token}"
-
-            payload = json.dumps(
-                {
-                    "messages": [{"role": "user", "content": prompt}],
-                    "temperature": 0.5,
-                    "top_p": 0.8,
-                    "penalty_score": 1,
-                    "disable_search": False,
-                    "enable_citation": False,
-                    "response_format": "text",
-                }
-            )
-            headers = {"Content-Type": "application/json"}
-
-            response = requests.request(
-                "POST", url, headers=headers, data=payload
-            ).json()
-            return response.get("result")
-
-        if llm_provider == "azure":
-            client = AzureOpenAI(
-                api_key=api_key,
-                api_version=api_version,
-                azure_endpoint=base_url,
-            )
-        else:
-            client = OpenAI(
-                api_key=api_key,
-                base_url=base_url,
-            )
-
-        response = client.chat.completions.create(
-            model=model_name, messages=[{"role": "user", "content": prompt}]
-        )
-        if response:
-            if isinstance(response, ChatCompletion):
-                content = response.choices[0].message.content
-            else:
-                raise Exception(
-                    f'[{llm_provider}] returned an invalid response: "{response}", please check your network '
-                    f"connection and try again."
-                )
-        else:
-            raise Exception(
-                f"[{llm_provider}] returned an empty response, please check your network connection and try again."
-            )
-
-    return content.replace("\n", "")
+    try:
+        content = ""
+        llm_provider = config.app.get("llm_provider", "openai")
+        logger.info(f"llm provider: {llm_provider}")
+        if llm_provider == "g4f":
+            model_name = config.app.get("g4f_model_name", "")
+            if not model_name:
+                model_name = "gpt-3.5-turbo-16k-0613"
+            content = g4f.ChatCompletion.create(
+                model=model_name,
+                messages=[{"role": "user", "content": prompt}],
+            )
+        else:
+            api_version = ""  # for azure
+            if llm_provider == "moonshot":
+                api_key = config.app.get("moonshot_api_key")
+                model_name = config.app.get("moonshot_model_name")
+                base_url = "https://api.moonshot.cn/v1"
+            elif llm_provider == "ollama":
+                # api_key = config.app.get("openai_api_key")
+                api_key = "ollama"  # any string works but you are required to have one
+                model_name = config.app.get("ollama_model_name")
+                base_url = config.app.get("ollama_base_url", "")
+                if not base_url:
+                    base_url = "http://localhost:11434/v1"
+            elif llm_provider == "openai":
+                api_key = config.app.get("openai_api_key")
+                model_name = config.app.get("openai_model_name")
+                base_url = config.app.get("openai_base_url", "")
+                if not base_url:
+                    base_url = "https://api.openai.com/v1"
+            elif llm_provider == "oneapi":
+                api_key = config.app.get("oneapi_api_key")
+                model_name = config.app.get("oneapi_model_name")
+                base_url = config.app.get("oneapi_base_url", "")
+            elif llm_provider == "azure":
+                api_key = config.app.get("azure_api_key")
+                model_name = config.app.get("azure_model_name")
+                base_url = config.app.get("azure_base_url", "")
+                api_version = config.app.get("azure_api_version", "2024-02-15-preview")
+            elif llm_provider == "gemini":
+                api_key = config.app.get("gemini_api_key")
+                model_name = config.app.get("gemini_model_name")
+                base_url = "***"
+            elif llm_provider == "qwen":
+                api_key = config.app.get("qwen_api_key")
+                model_name = config.app.get("qwen_model_name")
+                base_url = "***"
+            elif llm_provider == "cloudflare":
+                api_key = config.app.get("cloudflare_api_key")
+                model_name = config.app.get("cloudflare_model_name")
+                account_id = config.app.get("cloudflare_account_id")
+                base_url = "***"
+            elif llm_provider == "deepseek":
+                api_key = config.app.get("deepseek_api_key")
+                model_name = config.app.get("deepseek_model_name")
+                base_url = config.app.get("deepseek_base_url")
+                if not base_url:
+                    base_url = "https://api.deepseek.com"
+            elif llm_provider == "ernie":
+                api_key = config.app.get("ernie_api_key")
+                secret_key = config.app.get("ernie_secret_key")
+                base_url = config.app.get("ernie_base_url")
+                model_name = "***"
+                if not secret_key:
+                    raise ValueError(
+                        f"{llm_provider}: secret_key is not set, please set it in the config.toml file."
+                    )
+            else:
+                raise ValueError(
+                    "llm_provider is not set, please set it in the config.toml file."
+                )
+
+            if not api_key:
+                raise ValueError(
+                    f"{llm_provider}: api_key is not set, please set it in the config.toml file."
+                )
+            if not model_name:
+                raise ValueError(
+                    f"{llm_provider}: model_name is not set, please set it in the config.toml file."
+                )
+            if not base_url:
+                raise ValueError(
+                    f"{llm_provider}: base_url is not set, please set it in the config.toml file."
+                )
+
+            if llm_provider == "qwen":
+                import dashscope
+                from dashscope.api_entities.dashscope_response import GenerationResponse
+
+                dashscope.api_key = api_key
+                response = dashscope.Generation.call(
+                    model=model_name, messages=[{"role": "user", "content": prompt}]
+                )
+                if response:
+                    if isinstance(response, GenerationResponse):
+                        status_code = response.status_code
+                        if status_code != 200:
+                            raise Exception(
+                                f'[{llm_provider}] returned an error response: "{response}"'
+                            )
+
+                        content = response["output"]["text"]
+                        return content.replace("\n", "")
+                    else:
+                        raise Exception(
+                            f'[{llm_provider}] returned an invalid response: "{response}"'
+                        )
+                else:
+                    raise Exception(f"[{llm_provider}] returned an empty response")
+
+            if llm_provider == "gemini":
+                import google.generativeai as genai
+
+                genai.configure(api_key=api_key, transport="rest")
+
+                generation_config = {
+                    "temperature": 0.5,
+                    "top_p": 1,
+                    "top_k": 1,
+                    "max_output_tokens": 2048,
+                }
+
+                safety_settings = [
+                    {
+                        "category": "HARM_CATEGORY_HARASSMENT",
+                        "threshold": "BLOCK_ONLY_HIGH",
+                    },
+                    {
+                        "category": "HARM_CATEGORY_HATE_SPEECH",
+                        "threshold": "BLOCK_ONLY_HIGH",
+                    },
+                    {
+                        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                        "threshold": "BLOCK_ONLY_HIGH",
+                    },
+                    {
+                        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                        "threshold": "BLOCK_ONLY_HIGH",
+                    },
+                ]
+
+                model = genai.GenerativeModel(
+                    model_name=model_name,
+                    generation_config=generation_config,
+                    safety_settings=safety_settings,
+                )
+
+                try:
+                    response = model.generate_content(prompt)
+                    candidates = response.candidates
+                    generated_text = candidates[0].content.parts[0].text
+                except (AttributeError, IndexError) as e:
+                    print("Gemini Error:", e)
+
+                return generated_text
+
+            if llm_provider == "cloudflare":
+                import requests
+
+                response = requests.post(
+                    f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}",
+                    headers={"Authorization": f"Bearer {api_key}"},
+                    json={
+                        "messages": [
+                            {"role": "system", "content": "You are a friendly assistant"},
+                            {"role": "user", "content": prompt},
+                        ]
+                    },
+                )
+                result = response.json()
+                logger.info(result)
+                return result["result"]["response"]
+
+            if llm_provider == "ernie":
+                import requests
+
+                params = {
+                    "grant_type": "client_credentials",
+                    "client_id": api_key,
+                    "client_secret": secret_key,
+                }
+                access_token = (
+                    requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params)
+                    .json()
+                    .get("access_token")
+                )
+                url = f"{base_url}?access_token={access_token}"
+
+                payload = json.dumps(
+                    {
+                        "messages": [{"role": "user", "content": prompt}],
+                        "temperature": 0.5,
+                        "top_p": 0.8,
+                        "penalty_score": 1,
+                        "disable_search": False,
+                        "enable_citation": False,
+                        "response_format": "text",
+                    }
+                )
+                headers = {"Content-Type": "application/json"}
+
+                response = requests.request(
+                    "POST", url, headers=headers, data=payload
+                ).json()
+                return response.get("result")
+
+            if llm_provider == "azure":
+                client = AzureOpenAI(
+                    api_key=api_key,
+                    api_version=api_version,
+                    azure_endpoint=base_url,
+                )
+            else:
+                client = OpenAI(
+                    api_key=api_key,
+                    base_url=base_url,
+                )
+
+            response = client.chat.completions.create(
+                model=model_name, messages=[{"role": "user", "content": prompt}]
+            )
+            if response:
+                if isinstance(response, ChatCompletion):
+                    content = response.choices[0].message.content
+                else:
+                    raise Exception(
+                        f'[{llm_provider}] returned an invalid response: "{response}", please check your network '
+                        f"connection and try again."
+                    )
+            else:
+                raise Exception(
+                    f"[{llm_provider}] returned an empty response, please check your network connection and try again."
+                )
+
+        return content.replace("\n", "")
+    except Exception as e:
+        return f"Error: {str(e)}"
 
 
 def generate_script(
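All of the OpenAI-compatible providers (openai, moonshot, ollama, oneapi,
deepseek, azure) funnel into the same chat.completions.create() call at the end
of the function; only client construction differs. A hedged sketch of that
shared interface, with placeholder credentials, endpoint, and model name:

    from openai import AzureOpenAI, OpenAI

    use_azure = False  # placeholder switch
    if use_azure:
        client = AzureOpenAI(
            api_key="placeholder",
            api_version="2024-02-15-preview",
            azure_endpoint="https://example.openai.azure.com",
        )
    else:
        client = OpenAI(api_key="placeholder", base_url="https://api.openai.com/v1")

    response = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "hello"}],
    )
    print(response.choices[0].message.content)
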
@@ -319,8 +321,10 @@ Generate a script for a video, depending on the subject of the video.
 
         if i < _max_retries:
             logger.warning(f"failed to generate video script, trying again... {i + 1}")
-
-    logger.success(f"completed: \n{final_script}")
+    if "Error: " in final_script:
+        logger.error(f"failed to generate video script: {final_script}")
+    else:
+        logger.success(f"completed: \n{final_script}")
 
     return final_script.strip()
 
@@ -358,6 +362,9 @@ Please note that you must use English for generating video search terms; Chinese
     for i in range(_max_retries):
         try:
             response = _generate_response(prompt)
+            if "Error: " in response:
+                logger.error(f"failed to generate video script: {response}")
+                return response
             search_terms = json.loads(response)
             if not isinstance(search_terms, list) or not all(
                 isinstance(term, str) for term in search_terms
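The two hunks above make the retry loops sentinel-aware: a response containing
"Error: " is logged and returned immediately rather than retried or parsed as
JSON. A condensed, self-contained sketch of that flow (the function names and
stubbed response are mine for illustration, not the patch's):

    import json

    def _generate_response_stub(prompt: str) -> str:
        # stand-in for llm._generate_response(); may return an "Error: ..." string
        return '["search term one", "search term two"]'

    def generate_terms_sketch(prompt: str, max_retries: int = 5):
        for i in range(max_retries):
            response = _generate_response_stub(prompt)
            if "Error: " in response:
                return response  # propagate the sentinel instead of retrying
            try:
                return json.loads(response)
            except json.JSONDecodeError:
                if i < max_retries - 1:
                    continue
        return []
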
diff --git a/app/services/task.py b/app/services/task.py
index 83562d1..e3d9eb5 100644
--- a/app/services/task.py
+++ b/app/services/task.py
@@ -214,7 +214,7 @@ def start(task_id, params: VideoParams, stop_at: str = "video"):
 
     # 1. Generate script
     video_script = generate_script(task_id, params)
-    if not video_script:
+    if not video_script or "Error: " in video_script:
         sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
         return
 
diff --git a/webui/Main.py b/webui/Main.py
index 2dbe447..244c7e1 100644
--- a/webui/Main.py
+++ b/webui/Main.py
@@ -449,8 +449,12 @@ with left_panel:
         selected_index = st.selectbox(
             tr("Script Language"),
             index=0,
-            options=range(len(video_languages)),  # 使用索引作为内部选项值
-            format_func=lambda x: video_languages[x][0],  # 显示给用户的是标签
+            options=range(
+                len(video_languages)
+            ),  # Use the index as the internal option value
+            format_func=lambda x: video_languages[x][
+                0
+            ],  # The label is displayed to the user
         )
         params.video_language = video_languages[selected_index][1]
 
@@ -462,9 +466,13 @@ with left_panel:
                     video_subject=params.video_subject, language=params.video_language
                 )
                 terms = llm.generate_terms(params.video_subject, script)
-                st.session_state["video_script"] = script
-                st.session_state["video_terms"] = ", ".join(terms)
-
+                if "Error: " in script:
+                    st.error(tr(script))
+                elif "Error: " in terms:
+                    st.error(tr(terms))
+                else:
+                    st.session_state["video_script"] = script
+                    st.session_state["video_terms"] = ", ".join(terms)
         params.video_script = st.text_area(
             tr("Video Script"), value=st.session_state["video_script"], height=280
         )
@@ -475,7 +483,10 @@ with left_panel:
             with st.spinner(tr("Generating Video Keywords")):
                 terms = llm.generate_terms(params.video_subject, params.video_script)
-                st.session_state["video_terms"] = ", ".join(terms)
+                if "Error: " in terms:
+                    st.error(tr(terms))
+                else:
+                    st.session_state["video_terms"] = ", ".join(terms)
 
         params.video_terms = st.text_area(
             tr("Video Keywords"), value=st.session_state["video_terms"]
@@ -522,8 +533,12 @@ with middle_panel:
         selected_index = st.selectbox(
             tr("Video Concat Mode"),
             index=1,
-            options=range(len(video_concat_modes)),  # 使用索引作为内部选项值
-            format_func=lambda x: video_concat_modes[x][0],  # 显示给用户的是标签
+            options=range(
+                len(video_concat_modes)
+            ),  # Use the index as the internal option value
+            format_func=lambda x: video_concat_modes[x][
+                0
+            ],  # The label is displayed to the user
         )
         params.video_concat_mode = VideoConcatMode(
             video_concat_modes[selected_index][1]
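The selectbox hunks in this file all reformat the same pattern: options are
(label, value) pairs, the widget stores an index, format_func renders the
label, and the index is mapped back to the value afterwards. A minimal sketch
of the pattern (the sample data is invented for illustration):

    import streamlit as st

    video_aspect_ratios = [("Portrait 9:16", "9:16"), ("Landscape 16:9", "16:9")]
    selected_index = st.selectbox(
        "Video Ratio",
        options=range(len(video_aspect_ratios)),  # the index is the internal value
        format_func=lambda x: video_aspect_ratios[x][0],  # the label shown to the user
    )
    value = video_aspect_ratios[selected_index][1]
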
@@ -535,8 +550,12 @@ with middle_panel:
         ]
         selected_index = st.selectbox(
             tr("Video Ratio"),
-            options=range(len(video_aspect_ratios)),  # 使用索引作为内部选项值
-            format_func=lambda x: video_aspect_ratios[x][0],  # 显示给用户的是标签
+            options=range(
+                len(video_aspect_ratios)
+            ),  # Use the index as the internal option value
+            format_func=lambda x: video_aspect_ratios[x][
+                0
+            ],  # The label is displayed to the user
         )
         params.video_aspect = VideoAspect(video_aspect_ratios[selected_index][1])
 
@@ -648,13 +667,17 @@ with middle_panel:
         selected_index = st.selectbox(
             tr("Background Music"),
             index=1,
-            options=range(len(bgm_options)),  # 使用索引作为内部选项值
-            format_func=lambda x: bgm_options[x][0],  # 显示给用户的是标签
+            options=range(
+                len(bgm_options)
+            ),  # Use the index as the internal option value
+            format_func=lambda x: bgm_options[x][
+                0
+            ],  # The label is displayed to the user
         )
-        # 获取选择的背景音乐类型
+        # Get the selected background music type
         params.bgm_type = bgm_options[selected_index][1]
 
-        # 根据选择显示或隐藏组件
+        # Show or hide components based on the selection
         if params.bgm_type == "custom":
             custom_bgm_file = st.text_input(tr("Custom Background Music File"))
             if custom_bgm_file and os.path.exists(custom_bgm_file):
@@ -733,15 +756,6 @@ if start_button:
             scroll_to_bottom()
             st.stop()
 
-    if (
-        llm_provider != "g4f"
-        and llm_provider != "ollama"
-        and not config.app.get(f"{llm_provider}_api_key", "")
-    ):
-        st.error(tr("Please Enter the LLM API Key"))
-        scroll_to_bottom()
-        st.stop()
-
     if params.video_source not in ["pexels", "pixabay", "local"]:
         st.error(tr("Please Select a Valid Video Source"))
         scroll_to_bottom()
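The final hunk removes the up-front API-key check, so a missing key now
surfaces at generation time through the same "Error: " path. One caveat worth
keeping in mind when reading the "Error: " checks in this file: generate_terms()
returns a list on success and a string only on failure, and the `in` operator
behaves differently for the two types, substring containment for strings but
whole-element equality for lists:

    print("Error: " in "Error: api_key is not set")    # True  (substring match)
    print("Error: " in ["Error: api_key is not set"])  # False (element equality)

So the success path works only because no list element is ever exactly the
string "Error: ".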