complete task query interface

This commit is contained in:
harry 2024-04-01 20:12:14 +08:00
parent 95bc24453f
commit 9283787681
10 changed files with 139 additions and 23 deletions

View File

@ -46,6 +46,10 @@ def get_application() -> FastAPI:
app = get_application()
task_dir = utils.task_dir()
app.mount("/tasks", StaticFiles(directory=task_dir, html=True, follow_symlink=True), name="")
public_dir = utils.public_dir()
app.mount("/", StaticFiles(directory=public_dir, html=True), name="")

View File

@ -9,6 +9,7 @@ if not os.path.isfile(config_file):
example_file = f"{root_dir}/config.example.toml"
if os.path.isfile(example_file):
import shutil
shutil.copyfile(example_file, config_file)
logger.info(f"copy config.example.toml to config.toml")
@ -27,8 +28,9 @@ log_level = _cfg.get("log_level", "DEBUG")
listen_host = _cfg.get("listen_host", "0.0.0.0")
listen_port = _cfg.get("listen_port", 8080)
project_name = _cfg.get("project_name", "MoneyPrinterTurbo")
project_description = _cfg.get("project_description", "MoneyPrinterTurbo\n by 抖音-网旭哈瑞.AI")
project_version = _cfg.get("project_version", "1.0.0")
project_description = _cfg.get("project_description",
"<a href='https://github.com/harry0703/MoneyPrinterTurbo'>https://github.com/harry0703/MoneyPrinterTurbo</a>")
project_version = _cfg.get("project_version", "1.0.1")
reload_debug = False
imagemagick_path = app.get("imagemagick_path", "")

View File

@ -1,13 +1,13 @@
from os import path
from fastapi import Request, Depends, Path
from fastapi import Request, Depends, Path, BackgroundTasks
from loguru import logger
from app.config import config
from app.controllers import base
from app.controllers.v1.base import new_router
from app.models.exception import HttpException
from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest
from app.services import task as tm
from app.services import state as sm
from app.utils import utils
# 认证依赖项
@ -15,30 +15,43 @@ from app.utils import utils
router = new_router()
@router.post("/videos", response_model=TaskResponse, summary="使用主题来生成短视频")
def create_video(request: Request, body: TaskVideoRequest):
@router.post("/videos", response_model=TaskResponse, summary="Generate a short video")
def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest):
task_id = utils.get_uuid()
request_id = base.get_task_id(request)
try:
task = {
"task_id": task_id,
"request_id": request_id,
"params": body.dict(),
}
body_dict = body.dict()
task.update(body_dict)
result = tm.start(task_id=task_id, params=body)
task["result"] = result
sm.update_task(task_id)
background_tasks.add_task(tm.start, task_id=task_id, params=body)
logger.success(f"video created: {utils.to_json(task)}")
return utils.get_response(200, task)
except ValueError as e:
raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}")
@router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="查询任务状态")
def get_task(request: Request, task_id: str = Path(..., description="任务ID"),
query: TaskQueryRequest = Depends()):
@router.get("/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status")
def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
query: TaskQueryRequest = Depends()):
endpoint = config.app.get("endpoint", "")
if not endpoint:
endpoint = str(request.base_url)
endpoint = endpoint.rstrip("/")
request_id = base.get_task_id(request)
data = query.dict()
data["task_id"] = task_id
raise HttpException(task_id=task_id, status_code=404,
message=f"{request_id}: task not found", data=data)
task = sm.get_task(task_id)
if task:
if "videos" in task:
videos = task["videos"]
task_dir = utils.task_dir()
urls = []
for v in videos:
uri_path = v.replace(task_dir, "tasks")
urls.append(f"{endpoint}/{uri_path}")
task["videos"] = urls
return utils.get_response(200, task)
raise HttpException(task_id=task_id, status_code=404, message=f"{request_id}: task not found")

View File

@ -1,4 +1,8 @@
punctuations = [
PUNCTUATIONS = [
"?", ",", ".", "", ";", ":",
"", "", "", "", "", "",
]
TASK_STATE_FAILED = -1
TASK_STATE_COMPLETE = 1
TASK_STATE_PROCESSING = 4

View File

@ -136,7 +136,6 @@ class TaskQueryRequest(BaseModel):
class TaskResponse(BaseResponse):
class TaskResponseData(BaseModel):
task_id: str
task_type: str = ""
data: TaskResponseData

35
app/services/state.py Normal file
View File

@ -0,0 +1,35 @@
# State Management
# This module is responsible for managing the state of the application.
import math
# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。
# 如果你的应用程序是单机的,你可以使用内存来存储状态。
# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database.
# If your application is single-node, you can use memory to store the state.
from app.models import const
from app.utils import utils
_tasks = {}
def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
"""
Set the state of the task.
"""
progress = int(progress)
if progress > 100:
progress = 100
_tasks[task_id] = {
"state": state,
"progress": progress,
**kwargs,
}
def get_task(task_id: str):
"""
Get the state of the task.
"""
return _tasks.get(task_id, None)

View File

@ -6,8 +6,10 @@ from os import path
from loguru import logger
from app.config import config
from app.models import const
from app.models.schema import VideoParams, VideoConcatMode
from app.services import llm, material, voice, video, subtitle
from app.services import state as sm
from app.utils import utils
@ -26,6 +28,8 @@ def start(task_id, params: VideoParams):
}
"""
logger.info(f"start task: {task_id}")
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
video_subject = params.video_subject
voice_name = voice.parse_voice_name(params.voice_name)
paragraph_number = params.paragraph_number
@ -40,6 +44,8 @@ def start(task_id, params: VideoParams):
else:
logger.debug(f"video script: \n{video_script}")
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
logger.info("\n\n## generating video terms")
video_terms = params.video_terms
if not video_terms:
@ -63,10 +69,13 @@ def start(task_id, params: VideoParams):
with open(script_file, "w", encoding="utf-8") as f:
f.write(utils.to_json(script_data))
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
logger.info("\n\n## generating audio")
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
if sub_maker is None:
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error(
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
return
@ -74,6 +83,8 @@ def start(task_id, params: VideoParams):
audio_duration = voice.get_audio_duration(sub_maker)
audio_duration = math.ceil(audio_duration)
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
subtitle_path = ""
if params.subtitle_enabled:
subtitle_path = path.join(utils.task_dir(task_id), f"subtitle.srt")
@ -101,6 +112,8 @@ def start(task_id, params: VideoParams):
logger.warning(f"subtitle file is invalid: {subtitle_path}")
subtitle_path = ""
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
logger.info("\n\n## downloading videos")
downloaded_videos = material.download_videos(task_id=task_id,
search_terms=video_terms,
@ -110,15 +123,19 @@ def start(task_id, params: VideoParams):
max_clip_duration=max_clip_duration,
)
if not downloaded_videos:
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error(
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
return
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
final_video_paths = []
video_concat_mode = params.video_concat_mode
if params.video_count > 1:
video_concat_mode = VideoConcatMode.random
_progress = 50
for i in range(params.video_count):
index = i + 1
combined_video_path = path.join(utils.task_dir(task_id), f"combined-{index}.mp4")
@ -131,6 +148,9 @@ def start(task_id, params: VideoParams):
max_clip_duration=max_clip_duration,
threads=n_threads)
_progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress)
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
logger.info(f"\n\n## generating video: {index} => {final_video_path}")
@ -141,10 +161,16 @@ def start(task_id, params: VideoParams):
output_file=final_video_path,
params=params,
)
_progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress)
final_video_paths.append(final_video_path)
logger.success(f"task {task_id} finished, generated {len(final_video_paths)} videos.")
return {
kwargs = {
"videos": final_video_paths,
}
sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
return kwargs

View File

@ -149,7 +149,7 @@ def text_to_srt(idx: int, msg: str, start_time: float, end_time: float) -> str:
def str_contains_punctuation(word):
for p in const.punctuations:
for p in const.PUNCTUATIONS:
if p in word:
return True
return False
@ -159,7 +159,7 @@ def split_string_by_punctuations(s):
result = []
txt = ""
for char in s:
if char not in const.punctuations:
if char not in const.PUNCTUATIONS:
txt += char
else:
result.append(txt.strip())

View File

@ -97,6 +97,20 @@
# ffmpeg_path = "C:\\Users\\harry\\Downloads\\ffmpeg.exe"
#########################################################################################
# 当视频生成成功后API服务提供的视频下载接入点默认为当前服务的地址和监听端口
# 比如 http://127.0.0.1:8080/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4
# 如果你需要使用域名对外提供服务一般会用nginx做代理则可以设置为你的域名
# 比如 https://xxxx.com/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4
# endpoint="https://xxxx.com"
# When the video is successfully generated, the API service provides a download endpoint for the video, defaulting to the service's current address and listening port.
# For example, http://127.0.0.1:8080/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4
# If you need to provide the service externally using a domain name (usually done with nginx as a proxy), you can set it to your domain name.
# For example, https://xxxx.com/tasks/6357f542-a4e1-46a1-b4c9-bf3bd0df5285/final-1.mp4
# endpoint="https://xxxx.com"
endpoint=""
[whisper]
# Only effective when subtitle_provider is "whisper"

View File

@ -0,0 +1,19 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>MoneyPrinterTurbo</title>
</head>
<body>
<h1>MoneyPrinterTurbo</h1>
<a href="https://github.com/harry0703/MoneyPrinterTurbo">https://github.com/harry0703/MoneyPrinterTurbo</a>
<p>
只需提供一个视频 主题 或 关键词 ,就可以全自动生成视频文案、视频素材、视频字幕、视频背景音乐,然后合成一个高清的短视频。
</p>
<p>
Simply provide a topic or keyword for a video, and it will automatically generate the video copy, video materials,
video subtitles, and video background music before synthesizing a high-definition short video.
</p>
</body>
</html>