diff --git a/app/asgi.py b/app/asgi.py index 8a00b0d..440e4c0 100644 --- a/app/asgi.py +++ b/app/asgi.py @@ -46,6 +46,10 @@ def get_application() -> FastAPI: app = get_application() + +task_dir = utils.task_dir() +app.mount("/tasks", StaticFiles(directory=task_dir, html=True, follow_symlink=True), name="") + public_dir = utils.public_dir() app.mount("/", StaticFiles(directory=public_dir, html=True), name="") diff --git a/app/config/config.py b/app/config/config.py index 14322bd..0c3120b 100644 --- a/app/config/config.py +++ b/app/config/config.py @@ -9,6 +9,7 @@ if not os.path.isfile(config_file): example_file = f"{root_dir}/config.example.toml" if os.path.isfile(example_file): import shutil + shutil.copyfile(example_file, config_file) logger.info(f"copy config.example.toml to config.toml") @@ -27,8 +28,9 @@ log_level = _cfg.get("log_level", "DEBUG") listen_host = _cfg.get("listen_host", "0.0.0.0") listen_port = _cfg.get("listen_port", 8080) project_name = _cfg.get("project_name", "MoneyPrinterTurbo") -project_description = _cfg.get("project_description", "MoneyPrinterTurbo\n by 抖音-网旭哈瑞.AI") -project_version = _cfg.get("project_version", "1.0.0") +project_description = _cfg.get("project_description", + "https://github.com/harry0703/MoneyPrinterTurbo") +project_version = _cfg.get("project_version", "1.0.1") reload_debug = False imagemagick_path = app.get("imagemagick_path", "") diff --git a/app/controllers/v1/video.py b/app/controllers/v1/video.py index 0f450ee..7823509 100644 --- a/app/controllers/v1/video.py +++ b/app/controllers/v1/video.py @@ -1,13 +1,13 @@ -from os import path - -from fastapi import Request, Depends, Path +from fastapi import Request, Depends, Path, BackgroundTasks from loguru import logger +from app.config import config from app.controllers import base from app.controllers.v1.base import new_router from app.models.exception import HttpException from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest from app.services import task as tm +from app.services import state as sm from app.utils import utils # 认证依赖项 @@ -15,30 +15,43 @@ from app.utils import utils router = new_router() -@router.post("/videos", response_model=TaskResponse, summary="使用主题来生成短视频") -def create_video(request: Request, body: TaskVideoRequest): +@router.post("/videos", response_model=TaskResponse, summary="Generate a short video") +def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest): task_id = utils.get_uuid() request_id = base.get_task_id(request) try: task = { "task_id": task_id, "request_id": request_id, + "params": body.dict(), } - body_dict = body.dict() - task.update(body_dict) - result = tm.start(task_id=task_id, params=body) - task["result"] = result + sm.update_task(task_id) + background_tasks.add_task(tm.start, task_id=task_id, params=body) logger.success(f"video created: {utils.to_json(task)}") return utils.get_response(200, task) except ValueError as e: raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}") -@router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="查询任务状态") -def get_task(request: Request, task_id: str = Path(..., description="任务ID"), - query: TaskQueryRequest = Depends()): +@router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status") +def get_task(request: Request, task_id: str = Path(..., description="Task ID"), + query: TaskQueryRequest = Depends()): + endpoint = config.app.get("endpoint", "") + if not endpoint: + endpoint = str(request.base_url) + endpoint = endpoint.rstrip("/") + request_id = base.get_task_id(request) - data = query.dict() - data["task_id"] = task_id - raise HttpException(task_id=task_id, status_code=404, - message=f"{request_id}: task not found", data=data) + task = sm.get_task(task_id) + if task: + if "videos" in task: + videos = task["videos"] + task_dir = utils.task_dir() + urls = [] + for v in videos: + uri_path = v.replace(task_dir, "tasks") + urls.append(f"{endpoint}/{uri_path}") + task["videos"] = urls + return utils.get_response(200, task) + + raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found") diff --git a/app/models/const.py b/app/models/const.py index 0ea3b76..2aed3e0 100644 --- a/app/models/const.py +++ b/app/models/const.py @@ -1,4 +1,8 @@ -punctuations = [ +PUNCTUATIONS = [ "?", ",", ".", "、", ";", ":", "?", ",", "。", "、", ";", ":", ] + +TASK_STATE_FAILED = -1 +TASK_STATE_COMPLETE = 1 +TASK_STATE_PROCESSING = 4 diff --git a/app/models/schema.py b/app/models/schema.py index d29a6b4..a4170b9 100644 --- a/app/models/schema.py +++ b/app/models/schema.py @@ -136,7 +136,6 @@ class TaskQueryRequest(BaseModel): class TaskResponse(BaseResponse): class TaskResponseData(BaseModel): task_id: str - task_type: str = "" data: TaskResponseData diff --git a/app/services/state.py b/app/services/state.py new file mode 100644 index 0000000..606a2c1 --- /dev/null +++ b/app/services/state.py @@ -0,0 +1,35 @@ +# State Management +# This module is responsible for managing the state of the application. +import math + +# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。 +# 如果你的应用程序是单机的,你可以使用内存来存储状态。 + +# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database. +# If your application is single-node, you can use memory to store the state. + +from app.models import const +from app.utils import utils + +_tasks = {} + + +def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs): + """ + Set the state of the task. + """ + progress = int(progress) + if progress > 100: + progress = 100 + + _tasks[task_id] = { + "state": state, + "progress": progress, + **kwargs, + } + +def get_task(task_id: str): + """ + Get the state of the task. + """ + return _tasks.get(task_id, None) diff --git a/app/services/task.py b/app/services/task.py index bf091a8..fb636df 100644 --- a/app/services/task.py +++ b/app/services/task.py @@ -6,8 +6,10 @@ from os import path from loguru import logger from app.config import config +from app.models import const from app.models.schema import VideoParams, VideoConcatMode from app.services import llm, material, voice, video, subtitle +from app.services import state as sm from app.utils import utils @@ -26,6 +28,8 @@ def start(task_id, params: VideoParams): } """ logger.info(f"start task: {task_id}") + sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5) + video_subject = params.video_subject voice_name = voice.parse_voice_name(params.voice_name) paragraph_number = params.paragraph_number @@ -40,6 +44,8 @@ def start(task_id, params: VideoParams): else: logger.debug(f"video script: \n{video_script}") + sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10) + logger.info("\n\n## generating video terms") video_terms = params.video_terms if not video_terms: @@ -63,10 +69,13 @@ def start(task_id, params: VideoParams): with open(script_file, "w", encoding="utf-8") as f: f.write(utils.to_json(script_data)) + sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20) + logger.info("\n\n## generating audio") audio_file = path.join(utils.task_dir(task_id), f"audio.mp3") sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file) if sub_maker is None: + sm.update_task(task_id, state=const.TASK_STATE_FAILED) logger.error( "failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.") return @@ -74,6 +83,8 @@ def start(task_id, params: VideoParams): audio_duration = voice.get_audio_duration(sub_maker) audio_duration = math.ceil(audio_duration) + sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30) + subtitle_path = "" if params.subtitle_enabled: subtitle_path = path.join(utils.task_dir(task_id), f"subtitle.srt") @@ -101,6 +112,8 @@ def start(task_id, params: VideoParams): logger.warning(f"subtitle file is invalid: {subtitle_path}") subtitle_path = "" + sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40) + logger.info("\n\n## downloading videos") downloaded_videos = material.download_videos(task_id=task_id, search_terms=video_terms, @@ -110,15 +123,19 @@ def start(task_id, params: VideoParams): max_clip_duration=max_clip_duration, ) if not downloaded_videos: + sm.update_task(task_id, state=const.TASK_STATE_FAILED) logger.error( "failed to download videos, maybe the network is not available. if you are in China, please use a VPN.") return + sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50) + final_video_paths = [] video_concat_mode = params.video_concat_mode if params.video_count > 1: video_concat_mode = VideoConcatMode.random + _progress = 50 for i in range(params.video_count): index = i + 1 combined_video_path = path.join(utils.task_dir(task_id), f"combined-{index}.mp4") @@ -131,6 +148,9 @@ def start(task_id, params: VideoParams): max_clip_duration=max_clip_duration, threads=n_threads) + _progress += 50 / params.video_count / 2 + sm.update_task(task_id, progress=_progress) + final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4") logger.info(f"\n\n## generating video: {index} => {final_video_path}") @@ -141,10 +161,16 @@ def start(task_id, params: VideoParams): output_file=final_video_path, params=params, ) + + _progress += 50 / params.video_count / 2 + sm.update_task(task_id, progress=_progress) + final_video_paths.append(final_video_path) logger.success(f"task {task_id} finished, generated {len(final_video_paths)} videos.") - return { + kwargs = { "videos": final_video_paths, } + sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs) + return kwargs diff --git a/app/utils/utils.py b/app/utils/utils.py index 0dad08c..91a4433 100644 --- a/app/utils/utils.py +++ b/app/utils/utils.py @@ -149,7 +149,7 @@ def text_to_srt(idx: int, msg: str, start_time: float, end_time: float) -> str: def str_contains_punctuation(word): - for p in const.punctuations: + for p in const.PUNCTUATIONS: if p in word: return True return False @@ -159,7 +159,7 @@ def split_string_by_punctuations(s): result = [] txt = "" for char in s: - if char not in const.punctuations: + if char not in const.PUNCTUATIONS: txt += char else: result.append(txt.strip()) diff --git a/config.example.toml b/config.example.toml index 7aed6df..19d98f4 100644 --- a/config.example.toml +++ b/config.example.toml @@ -97,6 +97,20 @@ # ffmpeg_path = "C:\\Users\\harry\\Downloads\\ffmpeg.exe" ######################################################################################### + # 当视频生成成功后,API服务提供的视频下载接入点,默认为当前服务的地址和监听端口 + # 比如 http://127.0.0.1:8080/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4 + # 如果你需要使用域名对外提供服务(一般会用nginx做代理),则可以设置为你的域名 + # 比如 https://xxxx.com/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4 + # endpoint="https://xxxx.com" + + # When the video is successfully generated, the API service provides a download endpoint for the video, defaulting to the service's current address and listening port. + # For example, http://127.0.0.1:8080/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4 + # If you need to provide the service externally using a domain name (usually done with nginx as a proxy), you can set it to your domain name. + # For example, https://xxxx.com/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4 + # endpoint="https://xxxx.com" + endpoint="" + + [whisper] # Only effective when subtitle_provider is "whisper" diff --git a/resource/public/index.html b/resource/public/index.html new file mode 100644 index 0000000..45e8037 --- /dev/null +++ b/resource/public/index.html @@ -0,0 +1,19 @@ + + +
+ ++ 只需提供一个视频 主题 或 关键词 ,就可以全自动生成视频文案、视频素材、视频字幕、视频背景音乐,然后合成一个高清的短视频。 +
+ ++ Simply provide a topic or keyword for a video, and it will automatically generate the video copy, video materials, + video subtitles, and video background music before synthesizing a high-definition short video. +
+ + \ No newline at end of file