From d57434e0d31c8195dbcd3c86ff2763af96736cdf Mon Sep 17 00:00:00 2001 From: "kevin.zhang" Date: Wed, 10 Apr 2024 16:14:50 +0800 Subject: [PATCH 1/2] feat: add task deletion endpoint --- app/controllers/v1/video.py | 21 ++++++++++++++++++++- app/models/schema.py | 20 ++++++++++++++++++++ app/services/state.py | 7 +++++++ 3 files changed, 47 insertions(+), 1 deletion(-) diff --git a/app/controllers/v1/video.py b/app/controllers/v1/video.py index 8b08026..be9fb04 100644 --- a/app/controllers/v1/video.py +++ b/app/controllers/v1/video.py @@ -1,5 +1,7 @@ import os import glob +import shutil + from fastapi import Request, Depends, Path, BackgroundTasks, UploadFile from fastapi.params import File from loguru import logger @@ -9,7 +11,7 @@ from app.controllers import base from app.controllers.v1.base import new_router from app.models.exception import HttpException from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest, \ - BgmUploadResponse, BgmRetrieveResponse + BgmUploadResponse, BgmRetrieveResponse, TaskDeletionResponse from app.services import task as tm from app.services import state as sm from app.utils import utils @@ -75,6 +77,23 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"), raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found") +@router.delete("/tasks/{task_id}", response_model=TaskDeletionResponse, summary="Delete a generated short video task") +def create_video(request: Request, task_id: str = Path(..., description="Task ID")): + request_id = base.get_task_id(request) + task = sm.state.get_task(task_id) + if task: + tasks_dir = utils.task_dir() + current_task_dir = os.path.join(tasks_dir, task_id) + if os.path.exists(current_task_dir): + shutil.rmtree(current_task_dir) + + sm.state.delete_task(task_id) + logger.success(f"video deleted: {utils.to_json(task)}") + return utils.get_response(200, task) + + raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found") + + @router.get("/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files") def get_bgm_list(request: Request): suffix = "*.mp3" diff --git a/app/models/schema.py b/app/models/schema.py index 94a9859..29cf732 100644 --- a/app/models/schema.py +++ b/app/models/schema.py @@ -206,6 +206,26 @@ class TaskQueryResponse(BaseResponse): } +class TaskDeletionResponse(BaseResponse): + class Config: + json_schema_extra = { + "example": { + "status": 200, + "message": "success", + "data": { + "state": 1, + "progress": 100, + "videos": [ + "http://127.0.0.1:8080/tasks/6c85c8cc-a77a-42b9-bc30-947815aa0558/final-1.mp4" + ], + "combined_videos": [ + "http://127.0.0.1:8080/tasks/6c85c8cc-a77a-42b9-bc30-947815aa0558/combined-1.mp4" + ] + } + }, + } + + class VideoScriptResponse(BaseResponse): class Config: json_schema_extra = { diff --git a/app/services/state.py b/app/services/state.py index 0aa95ef..1bb32e0 100644 --- a/app/services/state.py +++ b/app/services/state.py @@ -38,6 +38,10 @@ class MemoryState(BaseState): def get_task(self, task_id: str): return self._tasks.get(task_id, None) + def delete_task(self, task_id: str): + if task_id in self._tasks: + del self._tasks[task_id] + # Redis state management class RedisState(BaseState): @@ -67,6 +71,9 @@ class RedisState(BaseState): task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()} return task + def delete_task(self, task_id: str): + self._redis.delete(task_id) + @staticmethod def _convert_to_original_type(value): """ From 2a251ebbd178a8b677a03578100bc2ef75160b34 Mon Sep 17 00:00:00 2001 From: cpanel10x Date: Wed, 10 Apr 2024 15:47:39 +0700 Subject: [PATCH 2/2] fix LLM API Key with g4f config fix LLM API Key with g4f config --- webui/Main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/webui/Main.py b/webui/Main.py index ad32c1f..2712b28 100644 --- a/webui/Main.py +++ b/webui/Main.py @@ -396,7 +396,7 @@ if start_button: scroll_to_bottom() st.stop() - if not config.app.get(f"{llm_provider}_api_key", ""): + if llm_provider != 'g4f' and not config.app.get(f"{llm_provider}_api_key", ""): st.error(tr("Please Enter the LLM API Key")) scroll_to_bottom() st.stop()