diff --git a/app/controllers/manager/base_manager.py b/app/controllers/manager/base_manager.py new file mode 100644 index 0000000..99cbf6f --- /dev/null +++ b/app/controllers/manager/base_manager.py @@ -0,0 +1,57 @@ +import threading +from typing import Callable, Any, Dict + + +class TaskManager: + def __init__(self, max_concurrent_tasks: int): + self.max_concurrent_tasks = max_concurrent_tasks + self.current_tasks = 0 + self.lock = threading.Lock() + self.queue = self.create_queue() + + def create_queue(self): + raise NotImplementedError() + + def add_task(self, func: Callable, *args: Any, **kwargs: Any): + with self.lock: + if self.current_tasks < self.max_concurrent_tasks: + print(f"add task: {func.__name__}, current_tasks: {self.current_tasks}") + self.execute_task(func, *args, **kwargs) + else: + print(f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}") + self.enqueue({"func": func, "args": args, "kwargs": kwargs}) + + def execute_task(self, func: Callable, *args: Any, **kwargs: Any): + thread = threading.Thread(target=self.run_task, args=(func, *args), kwargs=kwargs) + thread.start() + + def run_task(self, func: Callable, *args: Any, **kwargs: Any): + try: + with self.lock: + self.current_tasks += 1 + func(*args, **kwargs) # 在这里调用函数,传递*args和**kwargs + finally: + self.task_done() + + def check_queue(self): + with self.lock: + if self.current_tasks < self.max_concurrent_tasks and not self.is_queue_empty(): + task_info = self.dequeue() + func = task_info['func'] + args = task_info.get('args', ()) + kwargs = task_info.get('kwargs', {}) + self.execute_task(func, *args, **kwargs) + + def task_done(self): + with self.lock: + self.current_tasks -= 1 + self.check_queue() + + def enqueue(self, task: Dict): + raise NotImplementedError() + + def dequeue(self): + raise NotImplementedError() + + def is_queue_empty(self): + raise NotImplementedError() diff --git a/app/controllers/manager/memory_manager.py b/app/controllers/manager/memory_manager.py new file mode 100644 index 0000000..cf7321f --- /dev/null +++ b/app/controllers/manager/memory_manager.py @@ -0,0 +1,18 @@ +from queue import Queue +from typing import Dict + +from app.controllers.manager.base_manager import TaskManager + + +class InMemoryTaskManager(TaskManager): + def create_queue(self): + return Queue() + + def enqueue(self, task: Dict): + self.queue.put(task) + + def dequeue(self): + return self.queue.get() + + def is_queue_empty(self): + return self.queue.empty() diff --git a/app/controllers/manager/redis_manager.py b/app/controllers/manager/redis_manager.py new file mode 100644 index 0000000..a37c26c --- /dev/null +++ b/app/controllers/manager/redis_manager.py @@ -0,0 +1,48 @@ +import json +from typing import Dict + +import redis + +from app.controllers.manager.base_manager import TaskManager +from app.models.schema import VideoParams +from app.services import task as tm + +FUNC_MAP = { + 'start': tm.start, + # 'start_test': tm.start_test +} + + +class RedisTaskManager(TaskManager): + def __init__(self, max_concurrent_tasks: int, redis_url: str): + self.redis_client = redis.Redis.from_url(redis_url) + super().__init__(max_concurrent_tasks) + + def create_queue(self): + return "task_queue" + + def enqueue(self, task: Dict): + task_with_serializable_params = task.copy() + + if 'params' in task['kwargs'] and isinstance(task['kwargs']['params'], VideoParams): + task_with_serializable_params['kwargs']['params'] = task['kwargs']['params'].dict() + + # 将函数对象转换为其名称 + task_with_serializable_params['func'] = task['func'].__name__ + self.redis_client.rpush(self.queue, json.dumps(task_with_serializable_params)) + + def dequeue(self): + task_json = self.redis_client.lpop(self.queue) + if task_json: + task_info = json.loads(task_json) + # 将函数名称转换回函数对象 + task_info['func'] = FUNC_MAP[task_info['func']] + + if 'params' in task_info['kwargs'] and isinstance(task_info['kwargs']['params'], dict): + task_info['kwargs']['params'] = VideoParams(**task_info['kwargs']['params']) + + return task_info + return None + + def is_queue_empty(self): + return self.redis_client.llen(self.queue) == 0 diff --git a/app/controllers/v1/video.py b/app/controllers/v1/video.py index ab1a3a0..cc3ca3b 100644 --- a/app/controllers/v1/video.py +++ b/app/controllers/v1/video.py @@ -10,6 +10,8 @@ from loguru import logger from app.config import config from app.controllers import base +from app.controllers.manager.memory_manager import InMemoryTaskManager +from app.controllers.manager.redis_manager import RedisTaskManager from app.controllers.v1.base import new_router from app.models.exception import HttpException from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest, \ @@ -22,6 +24,35 @@ from app.utils import utils # router = new_router(dependencies=[Depends(base.verify_token)]) router = new_router() +_enable_redis = config.app.get("enable_redis", False) +_redis_host = config.app.get("redis_host", "localhost") +_redis_port = config.app.get("redis_port", 6379) +_redis_db = config.app.get("redis_db", 0) +_redis_password = config.app.get("redis_password", None) +_max_concurrent_tasks = config.app.get("max_concurrent_tasks", 5) + +redis_url = f"redis://:{_redis_password}@{_redis_host}:{_redis_port}/{_redis_db}" +# 根据配置选择合适的任务管理器 +if _enable_redis: + task_manager = RedisTaskManager(max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url) +else: + task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks) + +# @router.post("/videos-test", response_model=TaskResponse, summary="Generate a short video") +# async def create_video_test(request: Request, body: TaskVideoRequest): +# task_id = utils.get_uuid() +# request_id = base.get_task_id(request) +# try: +# task = { +# "task_id": task_id, +# "request_id": request_id, +# "params": body.dict(), +# } +# task_manager.add_task(tm.start_test, task_id=task_id, params=body) +# return utils.get_response(200, task) +# except ValueError as e: +# raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}") + @router.post("/videos", response_model=TaskResponse, summary="Generate a short video") def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest): @@ -34,7 +65,8 @@ def create_video(background_tasks: BackgroundTasks, request: Request, body: Task "params": body.dict(), } sm.state.update_task(task_id) - background_tasks.add_task(tm.start, task_id=task_id, params=body) + # background_tasks.add_task(tm.start, task_id=task_id, params=body) + task_manager.add_task(tm.start, task_id=task_id, params=body) logger.success(f"video created: {utils.to_json(task)}") return utils.get_response(200, task) except ValueError as e: diff --git a/app/models/schema.py b/app/models/schema.py index 29cf732..50ee918 100644 --- a/app/models/schema.py +++ b/app/models/schema.py @@ -73,7 +73,7 @@ class MaterialInfo: # ] -class VideoParams: +class VideoParams(BaseModel): """ { "video_subject": "", diff --git a/app/services/task.py b/app/services/task.py index 595ddc0..05fbc12 100644 --- a/app/services/task.py +++ b/app/services/task.py @@ -173,3 +173,9 @@ def start(task_id, params: VideoParams): } sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs) return kwargs + + +# def start_test(task_id, params: VideoParams): +# print(f"start task {task_id} \n") +# time.sleep(5) +# print(f"task {task_id} finished \n") diff --git a/config.example.toml b/config.example.toml index c19ac57..0348782 100644 --- a/config.example.toml +++ b/config.example.toml @@ -1,3 +1,6 @@ +listen_host = "127.0.0.1" +listen_port = 8502 + [app] # Pexels API Key # Register at https://www.pexels.com/api/ to get your API key. @@ -134,6 +137,10 @@ redis_host = "localhost" redis_port = 6379 redis_db = 0 + redis_password = "" + + # 文生视频时的最大并发任务数 + max_concurrent_tasks = 5 [whisper] # Only effective when subtitle_provider is "whisper"