diff --git a/README.md b/README.md index 4733920..03c87bb 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ - [x] 支持 **背景音乐**,随机或者指定音乐文件,可设置`背景音乐音量` - [x] 视频素材来源 **高清**,而且 **无版权**,也可以使用自己的 **本地素材** - [x] 支持 **OpenAI**、**Moonshot**、**Azure**、**gpt4free**、**one-api**、**通义千问**、**Google Gemini**、**Ollama**、 - **DeepSeek** 等多种模型接入 + **DeepSeek**、 **文心一言** 等多种模型接入 - 中国用户建议使用 **DeepSeek** 或 **Moonshot** 作为大模型提供商(国内可直接访问,不需要VPN。注册就送额度,基本够用) ### 后期计划 📅 diff --git a/app/services/llm.py b/app/services/llm.py index ac05137..40fe707 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -72,6 +72,13 @@ def _generate_response(prompt: str) -> str: base_url = config.app.get("deepseek_base_url") if not base_url: base_url = "https://api.deepseek.com" + elif llm_provider == "ernie": + api_key = config.app.get("ernie_api_key") + secret_key = config.app.get("ernie_secret_key") + base_url = config.app.get("ernie_base_url") + model_name = "***" + if not secret_key: + raise ValueError(f"{llm_provider}: secret_key is not set, please set it in the config.toml file.") else: raise ValueError("llm_provider is not set, please set it in the config.toml file.") @@ -165,6 +172,34 @@ def _generate_response(prompt: str) -> str: logger.info(result) return result["result"]["response"] + if llm_provider == "ernie": + import requests + params = {"grant_type": "client_credentials", "client_id": api_key, "client_secret": secret_key} + access_token = requests.post("https://aip.baidubce.com/oauth/2.0/token", params=params).json().get( + "access_token") + url = f"{base_url}?access_token={access_token}" + + payload = json.dumps({ + "messages": [ + { + "role": "user", + "content": prompt + } + ], + "temperature": 0.5, + "top_p": 0.8, + "penalty_score": 1, + "disable_search": False, + "enable_citation": False, + "response_format": "text" + }) + headers = { + 'Content-Type': 'application/json' + } + + response = requests.request("POST", url, headers=headers, data=payload).json() + return response.get("result") + if llm_provider == "azure": client = AzureOpenAI( api_key=api_key, @@ -239,7 +274,7 @@ Generate a script for a video, depending on the subject of the video. selected_paragraphs = paragraphs[:paragraph_number] # Join the selected paragraphs into a single string - return "\n\n".join(selected_paragraphs) + return "\n\n".join(paragraphs) for i in range(_max_retries): try: diff --git a/webui/Main.py b/webui/Main.py index 2cfa998..69979e2 100644 --- a/webui/Main.py +++ b/webui/Main.py @@ -174,6 +174,10 @@ if not config.app.get("hide_config", False): st.session_state['ui_language'] = code config.ui['language'] = code + # 是否禁用日志显示 + hide_log = st.checkbox(tr("Hide Log"), value=config.app.get("hide_log", False)) + config.ui['hide_log'] = hide_log + with middle_config_panel: # openai # moonshot (月之暗面) @@ -184,7 +188,7 @@ if not config.app.get("hide_config", False): # gemini # ollama llm_providers = ['OpenAI', 'Moonshot', 'Azure', 'Qwen', 'DeepSeek', 'Gemini', 'Ollama', 'G4f', 'OneAPI', - "Cloudflare"] + "Cloudflare", "ERNIE"] saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower() saved_llm_provider_index = 0 for i, provider in enumerate(llm_providers): @@ -198,6 +202,7 @@ if not config.app.get("hide_config", False): config.app["llm_provider"] = llm_provider llm_api_key = config.app.get(f"{llm_provider}_api_key", "") + llm_secret_key = config.app.get(f"{llm_provider}_secret_key", "") # only for baidu ernie llm_base_url = config.app.get(f"{llm_provider}_base_url", "") llm_model_name = config.app.get(f"{llm_provider}_model_name", "") llm_account_id = config.app.get(f"{llm_provider}_account_id", "") @@ -300,6 +305,15 @@ if not config.app.get("hide_config", False): - **Model Name**: 固定为 deepseek-chat """ + if llm_provider == 'ernie': + with llm_helper: + tips = """ + ##### 百度文心一言 配置说明 + - **API Key**: [点击到官网申请](https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application) + - **Secret Key**: [点击到官网申请](https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application) + - **Base Url**: 填写 **请求地址** [点击查看文档](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11#%E8%AF%B7%E6%B1%82%E8%AF%B4%E6%98%8E) + """ + if tips and config.ui['language'] == 'zh': st.warning( "中国用户建议使用 **DeepSeek** 或 **Moonshot** 作为大模型提供商\n- 国内可直接访问,不需要VPN \n- 注册就送额度,基本够用") @@ -307,7 +321,9 @@ if not config.app.get("hide_config", False): st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password") st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url) - st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name) + st_llm_model_name = "" + if llm_provider != 'ernie': + st.text_input(tr("Model Name"), value=llm_model_name) if st_llm_api_key: config.app[f"{llm_provider}_api_key"] = st_llm_api_key @@ -315,6 +331,9 @@ if not config.app.get("hide_config", False): config.app[f"{llm_provider}_base_url"] = st_llm_base_url if st_llm_model_name: config.app[f"{llm_provider}_model_name"] = st_llm_model_name + if llm_provider == 'ernie': + st_llm_secret_key = st.text_input(tr("Secret Key"), value=llm_secret_key, type="password") + config.app[f"{llm_provider}_secret_key"] = st_llm_secret_key if llm_provider == 'cloudflare': st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id) @@ -622,6 +641,8 @@ if start_button: def log_received(msg): + if config.ui['hide_log']: + return with log_container: log_records.append(msg) st.code("\n".join(log_records)) diff --git a/webui/i18n/en.json b/webui/i18n/en.json index a67edc2..3ca37ca 100644 --- a/webui/i18n/en.json +++ b/webui/i18n/en.json @@ -73,6 +73,7 @@ "Play Voice": "Play Voice", "Voice Example": "This is an example text for testing speech synthesis", "Synthesizing Voice": "Synthesizing voice, please wait...", - "TTS Provider": "Select the voice synthesis provider" + "TTS Provider": "Select the voice synthesis provider", + "Hide Log": "Hide Log" } } \ No newline at end of file diff --git a/webui/i18n/zh.json b/webui/i18n/zh.json index 0798b03..4d560da 100644 --- a/webui/i18n/zh.json +++ b/webui/i18n/zh.json @@ -73,6 +73,7 @@ "Play Voice": "试听语音合成", "Voice Example": "这是一段测试语音合成的示例文本", "Synthesizing Voice": "语音合成中,请稍候...", - "TTS Provider": "语音合成提供商" + "TTS Provider": "语音合成提供商", + "Hide Log": "隐藏日志" } } \ No newline at end of file