feat: add support for maximum concurrency of /api/v1/videos

This commit is contained in:
kevin.zhang 2024-04-16 17:47:56 +08:00
parent 414bcb0621
commit abe12abd7b
7 changed files with 170 additions and 2 deletions

View File

@ -0,0 +1,57 @@
import threading
from typing import Callable, Any, Dict
class TaskManager:
def __init__(self, max_concurrent_tasks: int):
self.max_concurrent_tasks = max_concurrent_tasks
self.current_tasks = 0
self.lock = threading.Lock()
self.queue = self.create_queue()
def create_queue(self):
raise NotImplementedError()
def add_task(self, func: Callable, *args: Any, **kwargs: Any):
with self.lock:
if self.current_tasks < self.max_concurrent_tasks:
print(f"add task: {func.__name__}, current_tasks: {self.current_tasks}")
self.execute_task(func, *args, **kwargs)
else:
print(f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}")
self.enqueue({"func": func, "args": args, "kwargs": kwargs})
def execute_task(self, func: Callable, *args: Any, **kwargs: Any):
thread = threading.Thread(target=self.run_task, args=(func, *args), kwargs=kwargs)
thread.start()
def run_task(self, func: Callable, *args: Any, **kwargs: Any):
try:
with self.lock:
self.current_tasks += 1
func(*args, **kwargs) # 在这里调用函数,传递*args和**kwargs
finally:
self.task_done()
def check_queue(self):
with self.lock:
if self.current_tasks < self.max_concurrent_tasks and not self.is_queue_empty():
task_info = self.dequeue()
func = task_info['func']
args = task_info.get('args', ())
kwargs = task_info.get('kwargs', {})
self.execute_task(func, *args, **kwargs)
def task_done(self):
with self.lock:
self.current_tasks -= 1
self.check_queue()
def enqueue(self, task: Dict):
raise NotImplementedError()
def dequeue(self):
raise NotImplementedError()
def is_queue_empty(self):
raise NotImplementedError()

View File

@ -0,0 +1,18 @@
from queue import Queue
from typing import Dict
from app.controllers.manager.base_manager import TaskManager
class InMemoryTaskManager(TaskManager):
def create_queue(self):
return Queue()
def enqueue(self, task: Dict):
self.queue.put(task)
def dequeue(self):
return self.queue.get()
def is_queue_empty(self):
return self.queue.empty()

View File

@ -0,0 +1,48 @@
import json
from typing import Dict
import redis
from app.controllers.manager.base_manager import TaskManager
from app.models.schema import VideoParams
from app.services import task as tm
FUNC_MAP = {
'start': tm.start,
# 'start_test': tm.start_test
}
class RedisTaskManager(TaskManager):
def __init__(self, max_concurrent_tasks: int, redis_url: str):
self.redis_client = redis.Redis.from_url(redis_url)
super().__init__(max_concurrent_tasks)
def create_queue(self):
return "task_queue"
def enqueue(self, task: Dict):
task_with_serializable_params = task.copy()
if 'params' in task['kwargs'] and isinstance(task['kwargs']['params'], VideoParams):
task_with_serializable_params['kwargs']['params'] = task['kwargs']['params'].dict()
# 将函数对象转换为其名称
task_with_serializable_params['func'] = task['func'].__name__
self.redis_client.rpush(self.queue, json.dumps(task_with_serializable_params))
def dequeue(self):
task_json = self.redis_client.lpop(self.queue)
if task_json:
task_info = json.loads(task_json)
# 将函数名称转换回函数对象
task_info['func'] = FUNC_MAP[task_info['func']]
if 'params' in task_info['kwargs'] and isinstance(task_info['kwargs']['params'], dict):
task_info['kwargs']['params'] = VideoParams(**task_info['kwargs']['params'])
return task_info
return None
def is_queue_empty(self):
return self.redis_client.llen(self.queue) == 0

View File

@ -10,6 +10,8 @@ from loguru import logger
from app.config import config
from app.controllers import base
from app.controllers.manager.memory_manager import InMemoryTaskManager
from app.controllers.manager.redis_manager import RedisTaskManager
from app.controllers.v1.base import new_router
from app.models.exception import HttpException
from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest, \
@ -22,6 +24,35 @@ from app.utils import utils
# router = new_router(dependencies=[Depends(base.verify_token)])
router = new_router()
_enable_redis = config.app.get("enable_redis", False)
_redis_host = config.app.get("redis_host", "localhost")
_redis_port = config.app.get("redis_port", 6379)
_redis_db = config.app.get("redis_db", 0)
_redis_password = config.app.get("redis_password", None)
_max_concurrent_tasks = config.app.get("max_concurrent_tasks", 5)
redis_url = f"redis://:{_redis_password}@{_redis_host}:{_redis_port}/{_redis_db}"
# 根据配置选择合适的任务管理器
if _enable_redis:
task_manager = RedisTaskManager(max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url)
else:
task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks)
# @router.post("/videos-test", response_model=TaskResponse, summary="Generate a short video")
# async def create_video_test(request: Request, body: TaskVideoRequest):
# task_id = utils.get_uuid()
# request_id = base.get_task_id(request)
# try:
# task = {
# "task_id": task_id,
# "request_id": request_id,
# "params": body.dict(),
# }
# task_manager.add_task(tm.start_test, task_id=task_id, params=body)
# return utils.get_response(200, task)
# except ValueError as e:
# raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}")
@router.post("/videos", response_model=TaskResponse, summary="Generate a short video")
def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest):
@ -34,7 +65,8 @@ def create_video(background_tasks: BackgroundTasks, request: Request, body: Task
"params": body.dict(),
}
sm.state.update_task(task_id)
background_tasks.add_task(tm.start, task_id=task_id, params=body)
# background_tasks.add_task(tm.start, task_id=task_id, params=body)
task_manager.add_task(tm.start, task_id=task_id, params=body)
logger.success(f"video created: {utils.to_json(task)}")
return utils.get_response(200, task)
except ValueError as e:

View File

@ -73,7 +73,7 @@ class MaterialInfo:
# ]
class VideoParams:
class VideoParams(BaseModel):
"""
{
"video_subject": "",

View File

@ -173,3 +173,9 @@ def start(task_id, params: VideoParams):
}
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
return kwargs
# def start_test(task_id, params: VideoParams):
# print(f"start task {task_id} \n")
# time.sleep(5)
# print(f"task {task_id} finished \n")

View File

@ -1,3 +1,6 @@
listen_host = "127.0.0.1"
listen_port = 8502
[app]
# Pexels API Key
# Register at https://www.pexels.com/api/ to get your API key.
@ -134,6 +137,10 @@
redis_host = "localhost"
redis_port = 6379
redis_db = 0
redis_password = ""
# 文生视频时的最大并发任务数
max_concurrent_tasks = 5
[whisper]
# Only effective when subtitle_provider is "whisper"