mirror of
https://github.com/harry0703/MoneyPrinterTurbo.git
synced 2026-02-21 16:37:21 +08:00
commit
ab5ae7072b
@ -16,12 +16,15 @@ RUN apt-get update && apt-get install -y \
|
||||
# Fix security policy for ImageMagick
|
||||
RUN sed -i '/<policy domain="path" rights="none" pattern="@\*"/d' /etc/ImageMagick-6/policy.xml
|
||||
|
||||
# Copy the current directory contents into the container at /MoneyPrinterTurbo
|
||||
COPY . .
|
||||
# Copy only the requirements.txt first to leverage Docker cache
|
||||
COPY requirements.txt ./
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Now copy the rest of the codebase into the image
|
||||
COPY . .
|
||||
|
||||
# Expose the port the app runs on
|
||||
EXPOSE 8501
|
||||
|
||||
|
||||
@ -56,7 +56,7 @@ listen_port = _cfg.get("listen_port", 8080)
|
||||
project_name = _cfg.get("project_name", "MoneyPrinterTurbo")
|
||||
project_description = _cfg.get("project_description",
|
||||
"<a href='https://github.com/harry0703/MoneyPrinterTurbo'>https://github.com/harry0703/MoneyPrinterTurbo</a>")
|
||||
project_version = _cfg.get("project_version", "1.1.2")
|
||||
project_version = _cfg.get("project_version", "1.1.3")
|
||||
reload_debug = False
|
||||
|
||||
imagemagick_path = app.get("imagemagick_path", "")
|
||||
|
||||
57
app/controllers/manager/base_manager.py
Normal file
57
app/controllers/manager/base_manager.py
Normal file
@ -0,0 +1,57 @@
|
||||
import threading
|
||||
from typing import Callable, Any, Dict
|
||||
|
||||
|
||||
class TaskManager:
|
||||
def __init__(self, max_concurrent_tasks: int):
|
||||
self.max_concurrent_tasks = max_concurrent_tasks
|
||||
self.current_tasks = 0
|
||||
self.lock = threading.Lock()
|
||||
self.queue = self.create_queue()
|
||||
|
||||
def create_queue(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def add_task(self, func: Callable, *args: Any, **kwargs: Any):
|
||||
with self.lock:
|
||||
if self.current_tasks < self.max_concurrent_tasks:
|
||||
print(f"add task: {func.__name__}, current_tasks: {self.current_tasks}")
|
||||
self.execute_task(func, *args, **kwargs)
|
||||
else:
|
||||
print(f"enqueue task: {func.__name__}, current_tasks: {self.current_tasks}")
|
||||
self.enqueue({"func": func, "args": args, "kwargs": kwargs})
|
||||
|
||||
def execute_task(self, func: Callable, *args: Any, **kwargs: Any):
|
||||
thread = threading.Thread(target=self.run_task, args=(func, *args), kwargs=kwargs)
|
||||
thread.start()
|
||||
|
||||
def run_task(self, func: Callable, *args: Any, **kwargs: Any):
|
||||
try:
|
||||
with self.lock:
|
||||
self.current_tasks += 1
|
||||
func(*args, **kwargs) # 在这里调用函数,传递*args和**kwargs
|
||||
finally:
|
||||
self.task_done()
|
||||
|
||||
def check_queue(self):
|
||||
with self.lock:
|
||||
if self.current_tasks < self.max_concurrent_tasks and not self.is_queue_empty():
|
||||
task_info = self.dequeue()
|
||||
func = task_info['func']
|
||||
args = task_info.get('args', ())
|
||||
kwargs = task_info.get('kwargs', {})
|
||||
self.execute_task(func, *args, **kwargs)
|
||||
|
||||
def task_done(self):
|
||||
with self.lock:
|
||||
self.current_tasks -= 1
|
||||
self.check_queue()
|
||||
|
||||
def enqueue(self, task: Dict):
|
||||
raise NotImplementedError()
|
||||
|
||||
def dequeue(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def is_queue_empty(self):
|
||||
raise NotImplementedError()
|
||||
18
app/controllers/manager/memory_manager.py
Normal file
18
app/controllers/manager/memory_manager.py
Normal file
@ -0,0 +1,18 @@
|
||||
from queue import Queue
|
||||
from typing import Dict
|
||||
|
||||
from app.controllers.manager.base_manager import TaskManager
|
||||
|
||||
|
||||
class InMemoryTaskManager(TaskManager):
|
||||
def create_queue(self):
|
||||
return Queue()
|
||||
|
||||
def enqueue(self, task: Dict):
|
||||
self.queue.put(task)
|
||||
|
||||
def dequeue(self):
|
||||
return self.queue.get()
|
||||
|
||||
def is_queue_empty(self):
|
||||
return self.queue.empty()
|
||||
48
app/controllers/manager/redis_manager.py
Normal file
48
app/controllers/manager/redis_manager.py
Normal file
@ -0,0 +1,48 @@
|
||||
import json
|
||||
from typing import Dict
|
||||
|
||||
import redis
|
||||
|
||||
from app.controllers.manager.base_manager import TaskManager
|
||||
from app.models.schema import VideoParams
|
||||
from app.services import task as tm
|
||||
|
||||
FUNC_MAP = {
|
||||
'start': tm.start,
|
||||
# 'start_test': tm.start_test
|
||||
}
|
||||
|
||||
|
||||
class RedisTaskManager(TaskManager):
|
||||
def __init__(self, max_concurrent_tasks: int, redis_url: str):
|
||||
self.redis_client = redis.Redis.from_url(redis_url)
|
||||
super().__init__(max_concurrent_tasks)
|
||||
|
||||
def create_queue(self):
|
||||
return "task_queue"
|
||||
|
||||
def enqueue(self, task: Dict):
|
||||
task_with_serializable_params = task.copy()
|
||||
|
||||
if 'params' in task['kwargs'] and isinstance(task['kwargs']['params'], VideoParams):
|
||||
task_with_serializable_params['kwargs']['params'] = task['kwargs']['params'].dict()
|
||||
|
||||
# 将函数对象转换为其名称
|
||||
task_with_serializable_params['func'] = task['func'].__name__
|
||||
self.redis_client.rpush(self.queue, json.dumps(task_with_serializable_params))
|
||||
|
||||
def dequeue(self):
|
||||
task_json = self.redis_client.lpop(self.queue)
|
||||
if task_json:
|
||||
task_info = json.loads(task_json)
|
||||
# 将函数名称转换回函数对象
|
||||
task_info['func'] = FUNC_MAP[task_info['func']]
|
||||
|
||||
if 'params' in task_info['kwargs'] and isinstance(task_info['kwargs']['params'], dict):
|
||||
task_info['kwargs']['params'] = VideoParams(**task_info['kwargs']['params'])
|
||||
|
||||
return task_info
|
||||
return None
|
||||
|
||||
def is_queue_empty(self):
|
||||
return self.redis_client.llen(self.queue) == 0
|
||||
@ -10,6 +10,8 @@ from loguru import logger
|
||||
|
||||
from app.config import config
|
||||
from app.controllers import base
|
||||
from app.controllers.manager.memory_manager import InMemoryTaskManager
|
||||
from app.controllers.manager.redis_manager import RedisTaskManager
|
||||
from app.controllers.v1.base import new_router
|
||||
from app.models.exception import HttpException
|
||||
from app.models.schema import TaskVideoRequest, TaskQueryResponse, TaskResponse, TaskQueryRequest, \
|
||||
@ -22,6 +24,35 @@ from app.utils import utils
|
||||
# router = new_router(dependencies=[Depends(base.verify_token)])
|
||||
router = new_router()
|
||||
|
||||
_enable_redis = config.app.get("enable_redis", False)
|
||||
_redis_host = config.app.get("redis_host", "localhost")
|
||||
_redis_port = config.app.get("redis_port", 6379)
|
||||
_redis_db = config.app.get("redis_db", 0)
|
||||
_redis_password = config.app.get("redis_password", None)
|
||||
_max_concurrent_tasks = config.app.get("max_concurrent_tasks", 5)
|
||||
|
||||
redis_url = f"redis://:{_redis_password}@{_redis_host}:{_redis_port}/{_redis_db}"
|
||||
# 根据配置选择合适的任务管理器
|
||||
if _enable_redis:
|
||||
task_manager = RedisTaskManager(max_concurrent_tasks=_max_concurrent_tasks, redis_url=redis_url)
|
||||
else:
|
||||
task_manager = InMemoryTaskManager(max_concurrent_tasks=_max_concurrent_tasks)
|
||||
|
||||
# @router.post("/videos-test", response_model=TaskResponse, summary="Generate a short video")
|
||||
# async def create_video_test(request: Request, body: TaskVideoRequest):
|
||||
# task_id = utils.get_uuid()
|
||||
# request_id = base.get_task_id(request)
|
||||
# try:
|
||||
# task = {
|
||||
# "task_id": task_id,
|
||||
# "request_id": request_id,
|
||||
# "params": body.dict(),
|
||||
# }
|
||||
# task_manager.add_task(tm.start_test, task_id=task_id, params=body)
|
||||
# return utils.get_response(200, task)
|
||||
# except ValueError as e:
|
||||
# raise HttpException(task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/videos", response_model=TaskResponse, summary="Generate a short video")
|
||||
def create_video(background_tasks: BackgroundTasks, request: Request, body: TaskVideoRequest):
|
||||
@ -34,7 +65,8 @@ def create_video(background_tasks: BackgroundTasks, request: Request, body: Task
|
||||
"params": body.dict(),
|
||||
}
|
||||
sm.state.update_task(task_id)
|
||||
background_tasks.add_task(tm.start, task_id=task_id, params=body)
|
||||
# background_tasks.add_task(tm.start, task_id=task_id, params=body)
|
||||
task_manager.add_task(tm.start, task_id=task_id, params=body)
|
||||
logger.success(f"video created: {utils.to_json(task)}")
|
||||
return utils.get_response(200, task)
|
||||
except ValueError as e:
|
||||
|
||||
@ -73,7 +73,7 @@ class MaterialInfo:
|
||||
# ]
|
||||
|
||||
|
||||
class VideoParams:
|
||||
class VideoParams(BaseModel):
|
||||
"""
|
||||
{
|
||||
"video_subject": "",
|
||||
|
||||
@ -173,3 +173,9 @@ def start(task_id, params: VideoParams):
|
||||
}
|
||||
sm.state.update_task(task_id, state=const.TASK_STATE_COMPLETE, progress=100, **kwargs)
|
||||
return kwargs
|
||||
|
||||
|
||||
# def start_test(task_id, params: VideoParams):
|
||||
# print(f"start task {task_id} \n")
|
||||
# time.sleep(5)
|
||||
# print(f"task {task_id} finished \n")
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
listen_host = "127.0.0.1"
|
||||
listen_port = 8502
|
||||
|
||||
[app]
|
||||
# Pexels API Key
|
||||
# Register at https://www.pexels.com/api/ to get your API key.
|
||||
@ -134,6 +137,10 @@
|
||||
redis_host = "localhost"
|
||||
redis_port = 6379
|
||||
redis_db = 0
|
||||
redis_password = ""
|
||||
|
||||
# 文生视频时的最大并发任务数
|
||||
max_concurrent_tasks = 5
|
||||
|
||||
[whisper]
|
||||
# Only effective when subtitle_provider is "whisper"
|
||||
|
||||
BIN
resource/fonts/MicrosoftYaHeiBold.ttc
Normal file
BIN
resource/fonts/MicrosoftYaHeiBold.ttc
Normal file
Binary file not shown.
BIN
resource/fonts/MicrosoftYaHeiNormal.ttc
Normal file
BIN
resource/fonts/MicrosoftYaHeiNormal.ttc
Normal file
Binary file not shown.
@ -223,7 +223,7 @@ left_panel = panel[0]
|
||||
middle_panel = panel[1]
|
||||
right_panel = panel[2]
|
||||
|
||||
params = VideoParams()
|
||||
params = VideoParams(video_subject="")
|
||||
|
||||
with left_panel:
|
||||
with st.container(border=True):
|
||||
@ -299,7 +299,8 @@ with middle_panel:
|
||||
index=0)
|
||||
with st.container(border=True):
|
||||
st.write(tr("Audio Settings"))
|
||||
voices = voice.get_all_azure_voices(filter_locals=["zh-CN", "zh-HK", "zh-TW", "de-DE", "en-US", "fr-FR", "vi-VN"])
|
||||
voices = voice.get_all_azure_voices(
|
||||
filter_locals=["zh-CN", "zh-HK", "zh-TW", "de-DE", "en-US", "fr-FR", "vi-VN"])
|
||||
friendly_names = {
|
||||
v: v.
|
||||
replace("Female", tr("Female")).
|
||||
|
||||
Loading…
Reference in New Issue
Block a user