feat: add redis support for task state management

This commit is contained in:
kevin.zhang 2024-04-10 10:42:56 +08:00
parent a0944fa358
commit 3d45348662
5 changed files with 111 additions and 44 deletions

View File

@ -29,7 +29,7 @@ def create_video(background_tasks: BackgroundTasks, request: Request, body: Task
"request_id": request_id,
"params": body.dict(),
}
sm.update_task(task_id)
sm.state.update_task(task_id)
background_tasks.add_task(tm.start, task_id=task_id, params=body)
logger.success(f"video created: {utils.to_json(task)}")
return utils.get_response(200, task)
@ -46,7 +46,7 @@ def get_task(request: Request, task_id: str = Path(..., description="Task ID"),
endpoint = endpoint.rstrip("/")
request_id = base.get_task_id(request)
task = sm.get_task(task_id)
task = sm.state.get_task(task_id)
if task:
task_dir = utils.task_dir()

View File

@ -1,35 +1,96 @@
# State Management
# This module is responsible for managing the state of the application.
import math
# 如果你部署在分布式环境中,你可能需要一个中心化的状态管理服务,比如 Redis 或者数据库。
# 如果你的应用程序是单机的,你可以使用内存来存储状态。
# If you are deploying in a distributed environment, you might need a centralized state management service like Redis or a database.
# If your application is single-node, you can use memory to store the state.
import ast
import json
from abc import ABC, abstractmethod
import redis
from app.config import config
from app.models import const
from app.utils import utils
_tasks = {}
def update_task(task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
"""
Set the state of the task.
"""
progress = int(progress)
if progress > 100:
progress = 100
# Base class for state management
class BaseState(ABC):
_tasks[task_id] = {
"state": state,
"progress": progress,
**kwargs,
}
@abstractmethod
def update_task(self, task_id: str, state: int, progress: int = 0, **kwargs):
pass
def get_task(task_id: str):
"""
Get the state of the task.
"""
return _tasks.get(task_id, None)
@abstractmethod
def get_task(self, task_id: str):
pass
# Memory state management
class MemoryState(BaseState):
def __init__(self):
self._tasks = {}
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
progress = int(progress)
if progress > 100:
progress = 100
self._tasks[task_id] = {
"state": state,
"progress": progress,
**kwargs,
}
def get_task(self, task_id: str):
return self._tasks.get(task_id, None)
# Redis state management
class RedisState(BaseState):
def __init__(self, host='localhost', port=6379, db=0):
self._redis = redis.StrictRedis(host=host, port=port, db=db)
def update_task(self, task_id: str, state: int = const.TASK_STATE_PROCESSING, progress: int = 0, **kwargs):
progress = int(progress)
if progress > 100:
progress = 100
fields = {
"state": state,
"progress": progress,
**kwargs,
}
for field, value in fields.items():
self._redis.hset(task_id, field, str(value))
def get_task(self, task_id: str):
task_data = self._redis.hgetall(task_id)
if not task_data:
return None
task = {key.decode('utf-8'): self._convert_to_original_type(value) for key, value in task_data.items()}
return task
@staticmethod
def _convert_to_original_type(value):
"""
Convert the value from byte string to its original data type.
You can extend this method to handle other data types as needed.
"""
value_str = value.decode('utf-8')
try:
# try to convert byte string array to list
return ast.literal_eval(value_str)
except (ValueError, SyntaxError):
pass
if value_str.isdigit():
return int(value_str)
# Add more conversions here if needed
return value_str
# Global state
_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
state = RedisState(host=_redis_host, port=_redis_port, db=_redis_db) if _enable_redis else MemoryState()

View File

@ -28,7 +28,7 @@ def start(task_id, params: VideoParams):
}
"""
logger.info(f"start task: {task_id}")
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=5)
video_subject = params.video_subject
voice_name = voice.parse_voice_name(params.voice_name)
@ -44,7 +44,7 @@ def start(task_id, params: VideoParams):
else:
logger.debug(f"video script: \n{video_script}")
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=10)
logger.info("\n\n## generating video terms")
video_terms = params.video_terms
@ -70,13 +70,13 @@ def start(task_id, params: VideoParams):
with open(script_file, "w", encoding="utf-8") as f:
f.write(utils.to_json(script_data))
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=20)
logger.info("\n\n## generating audio")
audio_file = path.join(utils.task_dir(task_id), f"audio.mp3")
sub_maker = voice.tts(text=video_script, voice_name=voice_name, voice_file=audio_file)
if sub_maker is None:
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error(
"failed to generate audio, maybe the network is not available. if you are in China, please use a VPN.")
return
@ -84,7 +84,7 @@ def start(task_id, params: VideoParams):
audio_duration = voice.get_audio_duration(sub_maker)
audio_duration = math.ceil(audio_duration)
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=30)
subtitle_path = ""
if params.subtitle_enabled:
@ -108,7 +108,7 @@ def start(task_id, params: VideoParams):
logger.warning(f"subtitle file is invalid: {subtitle_path}")
subtitle_path = ""
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=40)
logger.info("\n\n## downloading videos")
downloaded_videos = material.download_videos(task_id=task_id,
@ -119,12 +119,12 @@ def start(task_id, params: VideoParams):
max_clip_duration=max_clip_duration,
)
if not downloaded_videos:
sm.update_task(task_id, state=const.TASK_STATE_FAILED)
sm.state.update_task(task_id, state=const.TASK_STATE_FAILED)
logger.error(
"failed to download videos, maybe the network is not available. if you are in China, please use a VPN.")
return
sm.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
sm.state.update_task(task_id, state=const.TASK_STATE_PROCESSING, progress=50)
final_video_paths = []
combined_video_paths = []
@ -146,7 +146,7 @@ def start(task_id, params: VideoParams):
threads=n_threads)
_progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress)
sm.state.update_task(task_id, progress=_progress)
final_video_path = path.join(utils.task_dir(task_id), f"final-{index}.mp4")
@ -160,7 +160,7 @@ def start(task_id, params: VideoParams):
)
_progress += 50 / params.video_count / 2
sm.update_task(task_id, progress=_progress)
sm.state.update_task(task_id, progress=_progress)
final_video_paths.append(final_video_path)
combined_video_paths.append(combined_video_path)
@ -171,5 +171,5 @@ def start(task_id, params: VideoParams):
"videos": final_video_paths,
"combined_videos": combined_video_paths
}
sm.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
return kwargs

View File

@ -129,6 +129,11 @@
material_directory = ""
# Used for state management of the task
enable_redis = true
redis_host = "localhost"
redis_port = 6379
redis_db = 0
[whisper]
# Only effective when subtitle_provider is "whisper"

View File

@ -15,4 +15,5 @@ pydantic~=2.6.3
g4f~=0.2.5.4
dashscope~=1.15.0
google.generativeai~=0.4.1
python-multipart~=0.0.9
python-multipart~=0.0.9
redis==5.0.3