diff --git a/app/controllers/v1/video.py b/app/controllers/v1/video.py index cc3ca3b..0430707 100644 --- a/app/controllers/v1/video.py +++ b/app/controllers/v1/video.py @@ -1,11 +1,12 @@ -import os import glob +import os import pathlib import shutil +from typing import Union -from fastapi import Request, Depends, Path, BackgroundTasks, UploadFile -from fastapi.responses import FileResponse, StreamingResponse +from fastapi import BackgroundTasks, Depends, Path, Request, UploadFile from fastapi.params import File +from fastapi.responses import FileResponse, StreamingResponse from loguru import logger from app.config import config @@ -14,10 +15,19 @@ from app.controllers.manager.memory_manager import InMemoryTaskManager from app.controllers.manager.redis_manager import RedisTaskManager from app.controllers.v1.base import new_router from app.models.exception import HttpException -from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest, \ - BgmUploadResponse, BgmRetrieveResponse, TaskDeletionResponse -from app.services import task as tm +from app.models.schema import ( + AudioRequest, + BgmRetrieveResponse, + BgmUploadResponse, + SubtitleRequest, + TaskDeletionResponse, + TaskQueryRequest, + TaskQueryResponse, + TaskResponse, + TaskVideoRequest, +) from app.services import state as sm +from app.services import task as tm from app.utils import utils # 认证依赖项 @@ -34,48 +44,65 @@ _max_concurrent_tasks = config.app.get("max_concurrent_tasks", 5) redis_url = f"redis://:{_redis_password}@{_redis_host}:{_redis_port}/{_redis_db}" # 根据配置选择合适的任务管理器 if _enable_redis: - task_manager = RedisTaskManager(max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url) + task_manager = RedisTaskManager( + max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url + ) else: task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks) -# @router.post("/videos-test", response_model=TaskResponse, summary="Generate a short video") -# async def create_video_test(request: Request, body: TaskVideoRequest): -# task_id = utils.get_uuid() -# request_id = base.get_task_id(request) -# try: -# task = { -# "task_id": task_id, -# "request_id": request_id, -# "params": body.dict(), -# } -# task_manager.add_task(tm.start_test, task_id=task_id, params=body) -# return utils.get_response(200, task) -# except ValueError as e: -# raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}") - @router.post("/videos", response_model=TaskResponse, summary="Generate a short video") -def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest): +def create_video( + background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest +): + return create_task(request, body, stop_at="video") + + +@router.post("/subtitle", response_model=TaskResponse, summary="Generate subtitle only") +def create_subtitle( + background_tasks: BackgroundTasks, request: Request, body: SubtitleRequest +): + return create_task(request, body, stop_at="subtitle") + + +@router.post("/audio", response_model=TaskResponse, summary="Generate audio only") +def create_audio( + background_tasks: BackgroundTasks, request: Request, body: AudioRequest +): + return create_task(request, body, stop_at="audio") + + +def create_task( + request: Request, + body: Union[TaskVideoRequest, SubtitleRequest, AudioRequest], + stop_at: str, +): task_id = utils.get_uuid() request_id = base.get_task_id(request) try: task = { "task_id": task_id, "request_id": request_id, - "params": body.dict(), + "params": body.model_dump(), } sm.state.update_task(task_id) - # background_tasks.add_task(tm.start, task_id=task_id, params=body) - task_manager.add_task(tm.start, task_id=task_id, params=body) - logger.success(f"video created: {utils.to_json(task)}") + task_manager.add_task(tm.start, task_id=task_id, params=body, stop_at=stop_at) + logger.success(f"Task created: {utils.to_json(task)}") return utils.get_response(200, task) except ValueError as e: - raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}") + raise HttpException( + task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}" + ) -@router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status") -def get_task(request: Request, task_id: str = Path(..., description="Task ID"), - query: TaskQueryRequest = Depends()): +@router.get( + "/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status" +) +def get_task( + request: Request, + task_id: str = Path(..., description="Task ID"), + query: TaskQueryRequest = Depends(), +): endpoint = config.app.get("endpoint", "") if not endpoint: endpoint = str(request.base_url) @@ -108,10 +135,16 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"), task["combined_videos"] = urls return utils.get_response(200, task) - raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found") + raise HttpException( + task_id=task_id, status_code=404, message=f"{request_id}: task not found" + ) -@router.delete("/tasks/{task_id}", response_model=TaskDeletionResponse, summary="Delete a generated short video task") +@router.delete( + "/tasks/{task_id}", + response_model=TaskDeletionResponse, + summary="Delete a generated short video task", +) def delete_video(request: Request, task_id: str = Path(..., description="Task ID")): request_id = base.get_task_id(request) task = sm.state.get_task(task_id) @@ -125,32 +158,40 @@ def delete_video(request: Request, task_id: str = Path(..., description="Task ID logger.success(f"video deleted: {utils.to_json(task)}") return utils.get_response(200) - raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found") + raise HttpException( + task_id=task_id, status_code=404, message=f"{request_id}: task not found" + ) -@router.get("/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files") +@router.get( + "/musics", response_model=BgmRetrieveResponse, summary="Retrieve local BGM files" +) def get_bgm_list(request: Request): suffix = "*.mp3" song_dir = utils.song_dir() files = glob.glob(os.path.join(song_dir, suffix)) bgm_list = [] for file in files: - bgm_list.append({ - "name": os.path.basename(file), - "size": os.path.getsize(file), - "file": file, - }) - response = { - "files": bgm_list - } + bgm_list.append( + { + "name": os.path.basename(file), + "size": os.path.getsize(file), + "file": file, + } + ) + response = {"files": bgm_list} return utils.get_response(200, response) -@router.post("/musics", response_model=BgmUploadResponse, summary="Upload the BGM file to the songs directory") +@router.post( + "/musics", + response_model=BgmUploadResponse, + summary="Upload the BGM file to the songs directory", +) def upload_bgm_file(request: Request, file: UploadFile = File(...)): request_id = base.get_task_id(request) # check file ext - if file.filename.endswith('mp3'): + if file.filename.endswith("mp3"): song_dir = utils.song_dir() save_path = os.path.join(song_dir, file.filename) # save file @@ -158,26 +199,26 @@ def upload_bgm_file(request: Request, file: UploadFile = File(...)): # If the file already exists, it will be overwritten file.file.seek(0) buffer.write(file.file.read()) - response = { - "file": save_path - } + response = {"file": save_path} return utils.get_response(200, response) - raise HttpException('', status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded") + raise HttpException( + "", status_code=400, message=f"{request_id}: Only *.mp3 files can be uploaded" + ) @router.get("/stream/{file_path:path}") async def stream_video(request: Request, file_path: str): tasks_dir = utils.task_dir() video_path = os.path.join(tasks_dir, file_path) - range_header = request.headers.get('Range') + range_header = request.headers.get("Range") video_size = os.path.getsize(video_path) start, end = 0, video_size - 1 length = video_size if range_header: - range_ = range_header.split('bytes=')[1] - start, end = [int(part) if part else None for part in range_.split('-')] + range_ = range_header.split("bytes=")[1] + start, end = [int(part) if part else None for part in range_.split("-")] if start is None: start = video_size - end end = video_size - 1 @@ -186,7 +227,7 @@ async def stream_video(request: Request, file_path: str): length = end - start + 1 def file_iterator(file_path, offset=0, bytes_to_read=None): - with open(file_path, 'rb') as f: + with open(file_path, "rb") as f: f.seek(offset, os.SEEK_SET) remaining = bytes_to_read or video_size while remaining > 0: @@ -197,10 +238,12 @@ async def stream_video(request: Request, file_path: str): remaining -= len(data) yield data - response = StreamingResponse(file_iterator(video_path, start, length), media_type='video/mp4') - response.headers['Content-Range'] = f'bytes {start}-{end}/{video_size}' - response.headers['Accept-Ranges'] = 'bytes' - response.headers['Content-Length'] = str(length) + response = StreamingResponse( + file_iterator(video_path, start, length), media_type="video/mp4" + ) + response.headers["Content-Range"] = f"bytes {start}-{end}/{video_size}" + response.headers["Accept-Ranges"] = "bytes" + response.headers["Content-Length"] = str(length) response.status_code = 206 # Partial Content return response @@ -219,8 +262,10 @@ async def download_video(_: Request, file_path: str): file_path = pathlib.Path(video_path) filename = file_path.stem extension = file_path.suffix - headers = { - "Content-Disposition": f"attachment; filename={filename}{extension}" - } - return FileResponse(path=video_path, headers=headers, filename=f"{filename}{extension}", - media_type=f'video/{extension[1:]}') + headers = {"Content-Disposition": f"attachment; filename={filename}{extension}"} + return FileResponse( + path=video_path, + headers=headers, + filename=f"{filename}{extension}", + media_type=f"video/{extension[1:]}", + ) diff --git a/app/models/schema.py b/app/models/schema.py index c45d7d9..6ecb63a 100644 --- a/app/models/schema.py +++ b/app/models/schema.py @@ -1,12 +1,16 @@ +import warnings from enum import Enum -from typing import Any, Optional, List +from typing import Any, List, Optional import pydantic from pydantic import BaseModel -import warnings # 忽略 Pydantic 的特定警告 -warnings.filterwarnings("ignore", category=UserWarning, message="Field name.*shadows an attribute in parent.*") +warnings.filterwarnings( + "ignore", + category=UserWarning, + message="Field name.*shadows an attribute in parent.*", +) class VideoConcatMode(str, Enum): @@ -61,7 +65,6 @@ class MaterialInfo: # # "male-zh-TW-YunJheNeural", # # # en-US -# # "female-en-US-AnaNeural", # "female-en-US-AriaNeural", # "female-en-US-AvaNeural", @@ -93,6 +96,7 @@ class VideoParams(BaseModel): "stroke_width": 1.5 } """ + video_subject: str video_script: str = "" # 用于生成视频的脚本 video_terms: Optional[str | list] = None # 用于生成视频的关键词 @@ -126,6 +130,38 @@ class VideoParams(BaseModel): paragraph_number: Optional[int] = 1 +class SubtitleRequest(BaseModel): + video_script: str + video_language: Optional[str] = "" + voice_name: Optional[str] = "zh-CN-XiaoxiaoNeural-Female" + voice_volume: Optional[float] = 1.0 + voice_rate: Optional[float] = 1.2 + bgm_type: Optional[str] = "random" + bgm_file: Optional[str] = "" + bgm_volume: Optional[float] = 0.2 + subtitle_position: Optional[str] = "bottom" + font_name: Optional[str] = "STHeitiMedium.ttc" + text_fore_color: Optional[str] = "#FFFFFF" + text_background_color: Optional[str] = "transparent" + font_size: int = 60 + stroke_color: Optional[str] = "#000000" + stroke_width: float = 1.5 + video_source: Optional[str] = "local" + subtitle_enabled: Optional[str] = "true" + + +class AudioRequest(BaseModel): + video_script: str + video_language: Optional[str] = "" + voice_name: Optional[str] = "zh-CN-XiaoxiaoNeural-Female" + voice_volume: Optional[float] = 1.0 + voice_rate: Optional[float] = 1.2 + bgm_type: Optional[str] = "random" + bgm_file: Optional[str] = "" + bgm_volume: Optional[float] = 0.2 + video_source: Optional[str] = "local" + + class VideoScriptParams: """ { @@ -134,6 +170,7 @@ class VideoScriptParams: "paragraph_number": 1 } """ + video_subject: Optional[str] = "春天的花海" video_language: Optional[str] = "" paragraph_number: Optional[int] = 1 @@ -147,14 +184,17 @@ class VideoTermsParams: "amount": 5 } """ + video_subject: Optional[str] = "春天的花海" - video_script: Optional[str] = "春天的花海,如诗如画般展现在眼前。万物复苏的季节里,大地披上了一袭绚丽多彩的盛装。金黄的迎春、粉嫩的樱花、洁白的梨花、艳丽的郁金香……" + video_script: Optional[str] = ( + "春天的花海,如诗如画般展现在眼前。万物复苏的季节里,大地披上了一袭绚丽多彩的盛装。金黄的迎春、粉嫩的樱花、洁白的梨花、艳丽的郁金香……" + ) amount: Optional[int] = 5 class BaseResponse(BaseModel): status: int = 200 - message: Optional[str] = 'success' + message: Optional[str] = "success" data: Any = None @@ -189,9 +229,7 @@ class TaskResponse(BaseResponse): "example": { "status": 200, "message": "success", - "data": { - "task_id": "6c85c8cc-a77a-42b9-bc30-947815aa0558" - } + "data": {"task_id": "6c85c8cc-a77a-42b9-bc30-947815aa0558"}, }, } @@ -210,8 +248,8 @@ class TaskQueryResponse(BaseResponse): ], "combined_videos": [ "http://127.0.0.1:8080/tasks/6c85c8cc-a77a-42b9-bc30-947815aa0558/combined-1.mp4" - ] - } + ], + }, }, } @@ -230,8 +268,8 @@ class TaskDeletionResponse(BaseResponse): ], "combined_videos": [ "http://127.0.0.1:8080/tasks/6c85c8cc-a77a-42b9-bc30-947815aa0558/combined-1.mp4" - ] - } + ], + }, }, } @@ -244,7 +282,7 @@ class VideoScriptResponse(BaseResponse): "message": "success", "data": { "video_script": "春天的花海,是大自然的一幅美丽画卷。在这个季节里,大地复苏,万物生长,花朵争相绽放,形成了一片五彩斑斓的花海..." - } + }, }, } @@ -255,9 +293,7 @@ class VideoTermsResponse(BaseResponse): "example": { "status": 200, "message": "success", - "data": { - "video_terms": ["sky", "tree"] - } + "data": {"video_terms": ["sky", "tree"]}, }, } @@ -273,10 +309,10 @@ class BgmRetrieveResponse(BaseResponse): { "name": "output013.mp3", "size": 1891269, - "file": "/MoneyPrinterTurbo/resource/songs/output013.mp3" + "file": "/MoneyPrinterTurbo/resource/songs/output013.mp3", } ] - } + }, }, } @@ -287,8 +323,6 @@ class BgmUploadResponse(BaseResponse): "example": { "status": 200, "message": "success", - "data": { - "file": "/MoneyPrinterTurbo/resource/songs/example.mp3" - } + "data": {"file": "/MoneyPrinterTurbo/resource/songs/example.mp3"}, }, } diff --git a/app/services/task.py b/app/services/task.py index cf396c4..c2c1048 100644 --- a/app/services/task.py +++ b/app/services/task.py @@ -7,58 +7,42 @@ from loguru import logger from app.config import config from app.models import const -from app.models.schema import VideoParams, VideoConcatMode -from app.services import llm, material, voice, video, subtitle +from app.models.schema import VideoConcatMode, VideoParams +from app.services import llm, material, subtitle, video, voice from app.services import state as sm from app.utils import utils -def start(task_id, params: VideoParams): - """ - { - "video_subject": "", - "video_aspect": "横屏 16:9(西瓜视频)", - "voice_name": "女生-晓晓", - "enable_bgm": false, - "font_name": "STHeitiMedium 黑体-中", - "text_color": "#FFFFFF", - "font_size": 60, - "stroke_color": "#000000", - "stroke_width": 1.5 - } - """ - logger.info(f"start task: {task_id}") - sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5) - - video_subject = params.video_subject - voice_name = voice.parse_voice_name(params.voice_name) - voice_rate = params.voice_rate - paragraph_number = params.paragraph_number - n_threads = params.n_threads - max_clip_duration = params.video_clip_duration - +def generate_script(task_id, params): logger.info("\n\n## generating video script") video_script = params.video_script.strip() if not video_script: - video_script = llm.generate_script(video_subject=video_subject, language=params.video_language, - paragraph_number=paragraph_number) + video_script = llm.generate_script( + video_subject=params.video_subject, + language=params.video_language, + paragraph_number=params.paragraph_number, + ) else: logger.debug(f"video script: \n{video_script}") if not video_script: sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) logger.error("failed to generate video script.") - return + return None - sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10) + return video_script + +def generate_terms(task_id, params, video_script): logger.info("\n\n## generating video terms") video_terms = params.video_terms if not video_terms: - video_terms = llm.generate_terms(video_subject=video_subject, video_script=video_script, amount=5) + video_terms = llm.generate_terms( + video_subject=params.video_subject, video_script=video_script, amount=5 + ) else: if isinstance(video_terms, str): - video_terms = [term.strip() for term in re.split(r'[,,]', video_terms)] + video_terms = [term.strip() for term in re.split(r"[,,]", video_terms)] elif isinstance(video_terms, list): video_terms = [term.strip() for term in video_terms] else: @@ -69,9 +53,13 @@ def start(task_id, params: VideoParams): if not video_terms: sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) logger.error("failed to generate video terms.") - return + return None - script_file = path.join(utils.task_dir(task_id), f"script.json") + return video_terms + + +def save_script_data(task_id, video_script, video_terms, params): + script_file = path.join(utils.task_dir(task_id), "script.json") script_data = { "script": video_script, "search_terms": video_terms, @@ -81,11 +69,16 @@ def start(task_id, params: VideoParams): with open(script_file, "w", encoding="utf-8") as f: f.write(utils.to_json(script_data)) - sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20) +def generate_audio(task_id, params, video_script): logger.info("\n\n## generating audio") - audio_file = path.join(utils.task_dir(task_id), f"audio.mp3") - sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_rate=voice_rate, voice_file=audio_file) + audio_file = path.join(utils.task_dir(task_id), "audio.mp3") + sub_maker = voice.tts( + text=video_script, + voice_name=voice.parse_voice_name(params.voice_name), + voice_rate=params.voice_rate, + voice_file=audio_file, + ) if sub_maker is None: sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) logger.error( @@ -94,86 +87,100 @@ def start(task_id, params: VideoParams): 2. check if the network is available. If you are in China, it is recommended to use a VPN and enable the global traffic mode. """.strip() ) - return + return None, None - audio_duration = voice.get_audio_duration(sub_maker) - audio_duration = math.ceil(audio_duration) + audio_duration = math.ceil(voice.get_audio_duration(sub_maker)) + return audio_file, audio_duration - sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30) - subtitle_path = "" - if params.subtitle_enabled: - subtitle_path = path.join(utils.task_dir(task_id), f"subtitle.srt") - subtitle_provider = config.app.get("subtitle_provider", "").strip().lower() - logger.info(f"\n\n## generating subtitle, provider: {subtitle_provider}") - subtitle_fallback = False - if subtitle_provider == "edge": - voice.create_subtitle(text=video_script, sub_maker=sub_maker, subtitle_file=subtitle_path) - if not os.path.exists(subtitle_path): - subtitle_fallback = True - logger.warning("subtitle file not found, fallback to whisper") +def generate_subtitle(task_id, params, video_script, sub_maker, audio_file): + if not params.subtitle_enabled: + return "" - if subtitle_provider == "whisper" or subtitle_fallback: - subtitle.create(audio_file=audio_file, subtitle_file=subtitle_path) - logger.info("\n\n## correcting subtitle") - subtitle.correct(subtitle_file=subtitle_path, video_script=video_script) + subtitle_path = path.join(utils.task_dir(task_id), "subtitle.srt") + subtitle_provider = config.app.get("subtitle_provider", "").strip().lower() + logger.info(f"\n\n## generating subtitle, provider: {subtitle_provider}") - subtitle_lines = subtitle.file_to_subtitles(subtitle_path) - if not subtitle_lines: - logger.warning(f"subtitle file is invalid: {subtitle_path}") - subtitle_path = "" + subtitle_fallback = False + if subtitle_provider == "edge": + voice.create_subtitle( + text=video_script, sub_maker=sub_maker, subtitle_file=subtitle_path + ) + if not os.path.exists(subtitle_path): + subtitle_fallback = True + logger.warning("subtitle file not found, fallback to whisper") - sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40) + if subtitle_provider == "whisper" or subtitle_fallback: + subtitle.create(audio_file=audio_file, subtitle_file=subtitle_path) + logger.info("\n\n## correcting subtitle") + subtitle.correct(subtitle_file=subtitle_path, video_script=video_script) - downloaded_videos = [] + subtitle_lines = subtitle.file_to_subtitles(subtitle_path) + if not subtitle_lines: + logger.warning(f"subtitle file is invalid: {subtitle_path}") + return "" + + return subtitle_path + + +def get_video_materials(task_id, params, video_terms, audio_duration): if params.video_source == "local": logger.info("\n\n## preprocess local materials") - materials = video.preprocess_video(materials=params.video_materials, clip_duration=max_clip_duration) - print(materials) - + materials = video.preprocess_video( + materials=params.video_materials, clip_duration=params.video_clip_duration + ) if not materials: sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) - logger.error("no valid materials found, please check the materials and try again.") - return - for material_info in materials: - print(material_info) - downloaded_videos.append(material_info.url) + logger.error( + "no valid materials found, please check the materials and try again." + ) + return None + return [material_info.url for material_info in materials] else: logger.info(f"\n\n## downloading videos from {params.video_source}") - downloaded_videos = material.download_videos(task_id=task_id, - search_terms=video_terms, - source=params.video_source, - video_aspect=params.video_aspect, - video_contact_mode=params.video_concat_mode, - audio_duration=audio_duration * params.video_count, - max_clip_duration=max_clip_duration, - ) - if not downloaded_videos: - sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) - logger.error( - "failed to download videos, maybe the network is not available. if you are in China, please use a VPN.") - return + downloaded_videos = material.download_videos( + task_id=task_id, + search_terms=video_terms, + source=params.video_source, + video_aspect=params.video_aspect, + video_contact_mode=params.video_concat_mode, + audio_duration=audio_duration * params.video_count, + max_clip_duration=params.video_clip_duration, + ) + if not downloaded_videos: + sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) + logger.error( + "failed to download videos, maybe the network is not available. if you are in China, please use a VPN." + ) + return None + return downloaded_videos - sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50) +def generate_final_videos( + task_id, params, downloaded_videos, audio_file, subtitle_path +): final_video_paths = [] combined_video_paths = [] - video_concat_mode = params.video_concat_mode - if params.video_count > 1: - video_concat_mode = VideoConcatMode.random + video_concat_mode = ( + params.video_concat_mode if params.video_count > 1 else VideoConcatMode.random + ) _progress = 50 for i in range(params.video_count): index = i + 1 - combined_video_path = path.join(utils.task_dir(task_id), f"combined-{index}.mp4") + combined_video_path = path.join( + utils.task_dir(task_id), f"combined-{index}.mp4" + ) logger.info(f"\n\n## combining video: {index} => {combined_video_path}") - video.combine_videos(combined_video_path=combined_video_path, - video_paths=downloaded_videos, - audio_file=audio_file, - video_aspect=params.video_aspect, - video_concat_mode=video_concat_mode, - max_clip_duration=max_clip_duration, - threads=n_threads) + video.combine_videos( + combined_video_path=combined_video_path, + video_paths=downloaded_videos, + audio_file=audio_file, + video_aspect=params.video_aspect, + video_concat_mode=video_concat_mode, + max_clip_duration=params.video_clip_duration, + threads=params.n_threads, + ) _progress += 50 / params.video_count / 2 sm.state.update_task(task_id, progress=_progress) @@ -181,13 +188,13 @@ def start(task_id, params: VideoParams): final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4") logger.info(f"\n\n## generating video: {index} => {final_video_path}") - # Put everything together - video.generate_video(video_path=combined_video_path, - audio_path=audio_file, - subtitle_path=subtitle_path, - output_file=final_video_path, - params=params, - ) + video.generate_video( + video_path=combined_video_path, + audio_path=audio_file, + subtitle_path=subtitle_path, + output_file=final_video_path, + params=params, + ) _progress += 50 / params.video_count / 2 sm.state.update_task(task_id, progress=_progress) @@ -195,16 +202,119 @@ def start(task_id, params: VideoParams): final_video_paths.append(final_video_path) combined_video_paths.append(combined_video_path) - logger.success(f"task {task_id} finished, generated {len(final_video_paths)} videos.") + return final_video_paths, combined_video_paths + + +def start(task_id, params: VideoParams, stop_at: str = "video"): + logger.info(f"start task: {task_id}, stop_at: {stop_at}") + sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5) + + # 1. Generate script + video_script = generate_script(task_id, params) + if not video_script: + sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) + return + + sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10) + + if stop_at == "script": + sm.state.update_task( + task_id, state=const.TASK_STATE_COMPLETE, progress=100, script=video_script + ) + return {"script": video_script} + + # 2. Generate terms + video_terms = "" + if params.video_source != "local": + video_terms = generate_terms(task_id, params, video_script) + if not video_terms: + sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) + return + + save_script_data(task_id, video_script, video_terms, params) + + if stop_at == "terms": + sm.state.update_task( + task_id, state=const.TASK_STATE_COMPLETE, progress=100, terms=video_terms + ) + return {"script": video_script, "terms": video_terms} + + sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20) + + # 3. Generate audio + audio_file, audio_duration = generate_audio(task_id, params, video_script) + if not audio_file: + sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) + return + + sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30) + + if stop_at == "audio": + sm.state.update_task( + task_id, + state=const.TASK_STATE_COMPLETE, + progress=100, + audio_file=audio_file, + ) + return {"audio_file": audio_file, "audio_duration": audio_duration} + + # 4. Generate subtitle + subtitle_path = generate_subtitle(task_id, params, video_script, None, audio_file) + + if stop_at == "subtitle": + sm.state.update_task( + task_id, + state=const.TASK_STATE_COMPLETE, + progress=100, + subtitle_path=subtitle_path, + ) + return {"subtitle_path": subtitle_path} + + sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40) + + # 5. Get video materials + downloaded_videos = get_video_materials( + task_id, params, video_terms, audio_duration + ) + if not downloaded_videos: + sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) + return + + if stop_at == "materials": + sm.state.update_task( + task_id, + state=const.TASK_STATE_COMPLETE, + progress=100, + materials=downloaded_videos, + ) + return {"materials": downloaded_videos} + + sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50) + + # 6. Generate final videos + final_video_paths, combined_video_paths = generate_final_videos( + task_id, params, downloaded_videos, audio_file, subtitle_path + ) + + if not final_video_paths: + sm.state.update_task(task_id, state=const.TASK_STATE_FAILED) + return + + logger.success( + f"task {task_id} finished, generated {len(final_video_paths)} videos." + ) kwargs = { "videos": final_video_paths, - "combined_videos": combined_video_paths + "combined_videos": combined_video_paths, + "script": video_script, + "terms": video_terms, + "audio_file": audio_file, + "audio_duration": audio_duration, + "subtitle_path": subtitle_path, + "materials": downloaded_videos, } - sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs) + sm.state.update_task( + task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs + ) return kwargs - -# def start_test(task_id, params: VideoParams): -# print(f"start task {task_id} \n") -# time.sleep(5) -# print(f"task {task_id} finished \n")