mirror of
https://github.com/harry0703/MoneyPrinterTurbo.git
synced 2026-02-21 08:27:22 +08:00
complete task query interface
This commit is contained in:
parent
95bc24453f
commit
9283787681
@ -46,6 +46,10 @@ def get_application() -> FastAPI:
|
||||
|
||||
|
||||
app = get_application()
|
||||
|
||||
task_dir = utils.task_dir()
|
||||
app.mount("/tasks", StaticFiles(directory=task_dir, html=True, follow_symlink=True), name="")
|
||||
|
||||
public_dir = utils.public_dir()
|
||||
app.mount("/", StaticFiles(directory=public_dir, html=True), name="")
|
||||
|
||||
|
||||
@ -9,6 +9,7 @@ if not os.path.isfile(config_file):
|
||||
example_file = f"{root_dir}/config.example.toml"
|
||||
if os.path.isfile(example_file):
|
||||
import shutil
|
||||
|
||||
shutil.copyfile(example_file, config_file)
|
||||
logger.info(f"copy config.example.toml to config.toml")
|
||||
|
||||
@ -27,8 +28,9 @@ log_level = _cfg.get("log_level", "DEBUG")
|
||||
listen_host = _cfg.get("listen_host", "0.0.0.0")
|
||||
listen_port = _cfg.get("listen_port", 8080)
|
||||
project_name = _cfg.get("project_name", "MoneyPrinterTurbo")
|
||||
project_description = _cfg.get("project_description", "MoneyPrinterTurbo\n by 抖音-网旭哈瑞.AI")
|
||||
project_version = _cfg.get("project_version", "1.0.0")
|
||||
project_description = _cfg.get("project_description",
|
||||
"<a href='https://github.com/harry0703/MoneyPrinterTurbo'>https://github.com/harry0703/MoneyPrinterTurbo</a>")
|
||||
project_version = _cfg.get("project_version", "1.0.1")
|
||||
reload_debug = False
|
||||
|
||||
imagemagick_path = app.get("imagemagick_path", "")
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
from os import path
|
||||
|
||||
from fastapi import Request, Depends, Path
|
||||
from fastapi import Request, Depends, Path, BackgroundTasks
|
||||
from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.controllers import base
|
||||
from app.controllers.v1.base import new_router
|
||||
from app.models.exception import HttpException
|
||||
from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest
|
||||
from app.services import task as tm
|
||||
from app.services import state as sm
|
||||
from app.utils import utils
|
||||
|
||||
# 认证依赖项
|
||||
@ -15,30 +15,43 @@ from app.utils import utils
|
||||
router = new_router()
|
||||
|
||||
|
||||
@router.post("/videos", response_model=TaskResponse, summary="使用主题来生成短视频")
|
||||
def create_video(request: Request, body: TaskVideoRequest):
|
||||
@router.post("/videos", response_model=TaskResponse, summary="Generate a short video")
|
||||
def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest):
|
||||
task_id = utils.get_uuid()
|
||||
request_id = base.get_task_id(request)
|
||||
try:
|
||||
task = {
|
||||
"task_id": task_id,
|
||||
"request_id": request_id,
|
||||
"params": body.dict(),
|
||||
}
|
||||
body_dict = body.dict()
|
||||
task.update(body_dict)
|
||||
result = tm.start(task_id=task_id, params=body)
|
||||
task["result"] = result
|
||||
sm.update_task(task_id)
|
||||
background_tasks.add_task(tm.start, task_id=task_id, params=body)
|
||||
logger.success(f"video created: {utils.to_json(task)}")
|
||||
return utils.get_response(200, task)
|
||||
except ValueError as e:
|
||||
raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="查询任务状态")
|
||||
def get_task(request: Request, task_id: str = Path(..., description="任务ID"),
|
||||
query: TaskQueryRequest = Depends()):
|
||||
@router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status")
|
||||
def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
|
||||
query: TaskQueryRequest = Depends()):
|
||||
endpoint = config.app.get("endpoint", "")
|
||||
if not endpoint:
|
||||
endpoint = str(request.base_url)
|
||||
endpoint = endpoint.rstrip("/")
|
||||
|
||||
request_id = base.get_task_id(request)
|
||||
data = query.dict()
|
||||
data["task_id"] = task_id
|
||||
raise HttpException(task_id=task_id, status_code=404,
|
||||
message=f"{request_id}: task not found", data=data)
|
||||
task = sm.get_task(task_id)
|
||||
if task:
|
||||
if "videos" in task:
|
||||
videos = task["videos"]
|
||||
task_dir = utils.task_dir()
|
||||
urls = []
|
||||
for v in videos:
|
||||
uri_path = v.replace(task_dir, "tasks")
|
||||
urls.append(f"{endpoint}/{uri_path}")
|
||||
task["videos"] = urls
|
||||
return utils.get_response(200, task)
|
||||
|
||||
raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found")
|
||||
|
||||
@ -1,4 +1,8 @@
|
||||
punctuations = [
|
||||
PUNCTUATIONS = [
|
||||
"?", ",", ".", "、", ";", ":",
|
||||
"?", ",", "。", "、", ";", ":",
|
||||
]
|
||||
|
||||
TASK_STATE_FAILED = -1
|
||||
TASK_STATE_COMPLETE = 1
|
||||
TASK_STATE_PROCESSING = 4
|
||||
|
||||
@ -136,7 +136,6 @@ class TaskQueryRequest(BaseModel):
|
||||
class TaskResponse(BaseResponse):
|
||||
class TaskResponseData(BaseModel):
|
||||
task_id: str
|
||||
task_type: str = ""
|
||||
|
||||
data: TaskResponseData
|
||||
|
||||
|
||||
35
app/services/state.py
Normal file
35
app/services/state.py
Normal file
@ -0,0 +1,35 @@
|
||||
# State Management
|
||||
# This module is responsible for managing the state of the application.
|
||||
import math
|
||||
|
||||
# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。
|
||||
# 如果你的应用程序是单机的,你可以使用内存来存储状态。
|
||||
|
||||
# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database.
|
||||
# If your application is single-node, you can use memory to store the state.
|
||||
|
||||
from app.models import const
|
||||
from app.utils import utils
|
||||
|
||||
_tasks = {}
|
||||
|
||||
|
||||
def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
|
||||
"""
|
||||
Set the state of the task.
|
||||
"""
|
||||
progress = int(progress)
|
||||
if progress > 100:
|
||||
progress = 100
|
||||
|
||||
_tasks[task_id] = {
|
||||
"state": state,
|
||||
"progress": progress,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
def get_task(task_id: str):
|
||||
"""
|
||||
Get the state of the task.
|
||||
"""
|
||||
return _tasks.get(task_id, None)
|
||||
@ -6,8 +6,10 @@ from os import path
|
||||
from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.models import const
|
||||
from app.models.schema import VideoParams, VideoConcatMode
|
||||
from app.services import llm, material, voice, video, subtitle
|
||||
from app.services import state as sm
|
||||
from app.utils import utils
|
||||
|
||||
|
||||
@ -26,6 +28,8 @@ def start(task_id, params: VideoParams):
|
||||
}
|
||||
"""
|
||||
logger.info(f"start task: {task_id}")
|
||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
|
||||
|
||||
video_subject = params.video_subject
|
||||
voice_name = voice.parse_voice_name(params.voice_name)
|
||||
paragraph_number = params.paragraph_number
|
||||
@ -40,6 +44,8 @@ def start(task_id, params: VideoParams):
|
||||
else:
|
||||
logger.debug(f"video script: \n{video_script}")
|
||||
|
||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
|
||||
|
||||
logger.info("\n\n## generating video terms")
|
||||
video_terms = params.video_terms
|
||||
if not video_terms:
|
||||
@ -63,10 +69,13 @@ def start(task_id, params: VideoParams):
|
||||
with open(script_file, "w", encoding="utf-8") as f:
|
||||
f.write(utils.to_json(script_data))
|
||||
|
||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
|
||||
|
||||
logger.info("\n\n## generating audio")
|
||||
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
|
||||
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
|
||||
if sub_maker is None:
|
||||
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
logger.error(
|
||||
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
|
||||
return
|
||||
@ -74,6 +83,8 @@ def start(task_id, params: VideoParams):
|
||||
audio_duration = voice.get_audio_duration(sub_maker)
|
||||
audio_duration = math.ceil(audio_duration)
|
||||
|
||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
|
||||
|
||||
subtitle_path = ""
|
||||
if params.subtitle_enabled:
|
||||
subtitle_path = path.join(utils.task_dir(task_id), f"subtitle.srt")
|
||||
@ -101,6 +112,8 @@ def start(task_id, params: VideoParams):
|
||||
logger.warning(f"subtitle file is invalid: {subtitle_path}")
|
||||
subtitle_path = ""
|
||||
|
||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
|
||||
|
||||
logger.info("\n\n## downloading videos")
|
||||
downloaded_videos = material.download_videos(task_id=task_id,
|
||||
search_terms=video_terms,
|
||||
@ -110,15 +123,19 @@ def start(task_id, params: VideoParams):
|
||||
max_clip_duration=max_clip_duration,
|
||||
)
|
||||
if not downloaded_videos:
|
||||
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
|
||||
logger.error(
|
||||
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
|
||||
return
|
||||
|
||||
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
|
||||
|
||||
final_video_paths = []
|
||||
video_concat_mode = params.video_concat_mode
|
||||
if params.video_count > 1:
|
||||
video_concat_mode = VideoConcatMode.random
|
||||
|
||||
_progress = 50
|
||||
for i in range(params.video_count):
|
||||
index = i + 1
|
||||
combined_video_path = path.join(utils.task_dir(task_id), f"combined-{index}.mp4")
|
||||
@ -131,6 +148,9 @@ def start(task_id, params: VideoParams):
|
||||
max_clip_duration=max_clip_duration,
|
||||
threads=n_threads)
|
||||
|
||||
_progress += 50 / params.video_count / 2
|
||||
sm.update_task(task_id, progress=_progress)
|
||||
|
||||
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
|
||||
|
||||
logger.info(f"\n\n## generating video: {index} => {final_video_path}")
|
||||
@ -141,10 +161,16 @@ def start(task_id, params: VideoParams):
|
||||
output_file=final_video_path,
|
||||
params=params,
|
||||
)
|
||||
|
||||
_progress += 50 / params.video_count / 2
|
||||
sm.update_task(task_id, progress=_progress)
|
||||
|
||||
final_video_paths.append(final_video_path)
|
||||
|
||||
logger.success(f"task {task_id} finished, generated {len(final_video_paths)} videos.")
|
||||
|
||||
return {
|
||||
kwargs = {
|
||||
"videos": final_video_paths,
|
||||
}
|
||||
sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
|
||||
return kwargs
|
||||
|
||||
@ -149,7 +149,7 @@ def text_to_srt(idx: int, msg: str, start_time: float, end_time: float) -> str:
|
||||
|
||||
|
||||
def str_contains_punctuation(word):
|
||||
for p in const.punctuations:
|
||||
for p in const.PUNCTUATIONS:
|
||||
if p in word:
|
||||
return True
|
||||
return False
|
||||
@ -159,7 +159,7 @@ def split_string_by_punctuations(s):
|
||||
result = []
|
||||
txt = ""
|
||||
for char in s:
|
||||
if char not in const.punctuations:
|
||||
if char not in const.PUNCTUATIONS:
|
||||
txt += char
|
||||
else:
|
||||
result.append(txt.strip())
|
||||
|
||||
@ -97,6 +97,20 @@
|
||||
# ffmpeg_path = "C:\\Users\\harry\\Downloads\\ffmpeg.exe"
|
||||
#########################################################################################
|
||||
|
||||
# 当视频生成成功后,API服务提供的视频下载接入点,默认为当前服务的地址和监听端口
|
||||
# 比如 http://127.0.0.1:8080/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4
|
||||
# 如果你需要使用域名对外提供服务(一般会用nginx做代理),则可以设置为你的域名
|
||||
# 比如 https://xxxx.com/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4
|
||||
# endpoint="https://xxxx.com"
|
||||
|
||||
# When the video is successfully generated, the API service provides a download endpoint for the video, defaulting to the service's current address and listening port.
|
||||
# For example, http://127.0.0.1:8080/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4
|
||||
# If you need to provide the service externally using a domain name (usually done with nginx as a proxy), you can set it to your domain name.
|
||||
# For example, https://xxxx.com/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4
|
||||
# endpoint="https://xxxx.com"
|
||||
endpoint=""
|
||||
|
||||
|
||||
[whisper]
|
||||
# Only effective when subtitle_provider is "whisper"
|
||||
|
||||
|
||||
19
resource/public/index.html
Normal file
19
resource/public/index.html
Normal file
@ -0,0 +1,19 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>MoneyPrinterTurbo</title>
|
||||
</head>
|
||||
<body>
|
||||
<h1>MoneyPrinterTurbo</h1>
|
||||
<a href="https://github.com/harry0703/MoneyPrinterTurbo">https://github.com/harry0703/MoneyPrinterTurbo</a>
|
||||
<p>
|
||||
只需提供一个视频 主题 或 关键词 ,就可以全自动生成视频文案、视频素材、视频字幕、视频背景音乐,然后合成一个高清的短视频。
|
||||
</p>
|
||||
|
||||
<p>
|
||||
Simply provide a topic or keyword for a video, and it will automatically generate the video copy, video materials,
|
||||
video subtitles, and video background music before synthesizing a high-definition short video.
|
||||
</p>
|
||||
</body>
|
||||
</html>
|
||||
Loading…
Reference in New Issue
Block a user