From 05da4a37669a48c8e9b9298fb5501b09706d84b0 Mon Sep 17 00:00:00 2001 From: highkay Date: Thu, 11 Apr 2024 22:55:08 +0800 Subject: [PATCH] =?UTF-8?q?-=20=E5=A2=9E=E5=8A=A0Cloudflare=20workers=20ai?= =?UTF-8?q?=E4=BD=9C=E4=B8=BAllm=E5=90=8E=E7=AB=AF=20-=20=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E4=B8=80=E4=BA=9Bgitignore?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 4 +++- app/services/llm.py | 21 +++++++++++++++++++++ webui/Main.py | 8 +++++++- webui/i18n/en.json | 1 + webui/i18n/zh.json | 1 + 5 files changed, 33 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index dd7e788..00b64c8 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,6 @@ /app/utils/__pycache__/ /*/__pycache__/* .vscode -/**/.streamlit \ No newline at end of file +/**/.streamlit +__pycache__ +logs/ \ No newline at end of file diff --git a/app/services/llm.py b/app/services/llm.py index fca6bca..fda371f 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -57,6 +57,11 @@ def _generate_response(prompt: str) -> str: api_key = config.app.get("qwen_api_key") model_name = config.app.get("qwen_model_name") base_url = "***" + elif llm_provider == "cloudflare": + api_key = config.app.get("cloudflare_api_key") + model_name = config.app.get("cloudflare_model_name") + account_id = config.app.get("cloudflare_account_id") + base_url = "***" else: raise ValueError("llm_provider is not set, please set it in the config.toml file.") @@ -115,6 +120,22 @@ def _generate_response(prompt: str) -> str: convo.send_message(prompt) return convo.last.text + + if llm_provider == "cloudflare": + import requests + response = requests.post( + f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/{model_name}", + headers={"Authorization": f"Bearer {api_key}"}, + json={ + "messages": [ + {"role": "system", "content": "You are a friendly assistant"}, + {"role": "user", "content": prompt} + ] + } + ) + result = response.json() + logger.info(result) + return result["result"]["response"] if llm_provider == "azure": client = AzureOpenAI( diff --git a/webui/Main.py b/webui/Main.py index 2712b28..9631d39 100644 --- a/webui/Main.py +++ b/webui/Main.py @@ -175,7 +175,7 @@ with st.expander(tr("Basic Settings"), expanded=False): # qwen (通义千问) # gemini # ollama - llm_providers = ['OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Gemini', 'Ollama', 'G4f', 'OneAPI'] + llm_providers = ['OpenAI', 'Moonshot', 'Azure', 'Qwen', 'Gemini', 'Ollama', 'G4f', 'OneAPI', "Cloudflare"] saved_llm_provider = config.app.get("llm_provider", "OpenAI").lower() saved_llm_provider_index = 0 for i, provider in enumerate(llm_providers): @@ -190,6 +190,7 @@ with st.expander(tr("Basic Settings"), expanded=False): llm_api_key = config.app.get(f"{llm_provider}_api_key", "") llm_base_url = config.app.get(f"{llm_provider}_base_url", "") llm_model_name = config.app.get(f"{llm_provider}_model_name", "") + llm_account_id = config.app.get(f"{llm_provider}_account_id", "") st_llm_api_key = st.text_input(tr("API Key"), value=llm_api_key, type="password") st_llm_base_url = st.text_input(tr("Base Url"), value=llm_base_url) st_llm_model_name = st.text_input(tr("Model Name"), value=llm_model_name) @@ -200,6 +201,11 @@ with st.expander(tr("Basic Settings"), expanded=False): if st_llm_model_name: config.app[f"{llm_provider}_model_name"] = st_llm_model_name + if llm_provider == 'cloudflare': + st_llm_account_id = st.text_input(tr("Account ID"), value=llm_account_id) + if st_llm_account_id: + config.app[f"{llm_provider}_account_id"] = st_llm_account_id + config.save_config() with right_config_panel: diff --git a/webui/i18n/en.json b/webui/i18n/en.json index 47bd073..f9458a9 100644 --- a/webui/i18n/en.json +++ b/webui/i18n/en.json @@ -55,6 +55,7 @@ "LLM Provider": "LLM Provider", "API Key": "API Key (:red[Required])", "Base Url": "Base Url", + "Account ID": "Account ID (Get from Cloudflare dashboard)", "Model Name": "Model Name", "Please Enter the LLM API Key": "Please Enter the **LLM API Key**", "Please Enter the Pexels API Key": "Please Enter the **Pexels API Key**", diff --git a/webui/i18n/zh.json b/webui/i18n/zh.json index 7d71d05..7019c27 100644 --- a/webui/i18n/zh.json +++ b/webui/i18n/zh.json @@ -55,6 +55,7 @@ "LLM Provider": "大模型提供商", "API Key": "API Key (:red[必填,需要到大模型提供商的后台申请])", "Base Url": "Base Url (可选)", + "Account ID": "账户ID (Cloudflare的dash面板url中获取)", "Model Name": "模型名称 (:blue[需要到大模型提供商的后台确认被授权的模型名称])", "Please Enter the LLM API Key": "请先填写大模型 **API Key**", "Please Enter the Pexels API Key": "请先填写 **Pexels API Key**",