Merge pull request #603 from garymengcom/main

Add get_all_tasks() endpoint and update .gitignore
This commit is contained in:
Harry 2025-03-23 18:40:52 +08:00 committed by GitHub
commit 6d2e4a8081
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 55 additions and 1 deletions

5
.gitignore vendored
View File

@ -22,4 +22,7 @@ node_modules
/sites/docs/.vuepress/dist
# 模型目录
/models/
./models/*
./models/*
venv/
.venv

View File

@ -94,6 +94,22 @@ def create_task(
task_id=task_id, status_code=400, message=f"{request_id}: {str(e)}"
)
from fastapi import Query
@router.get("/tasks", response_model=TaskQueryResponse, summary="Get all tasks")
def get_all_tasks(request: Request, page: int = Query(1, ge=1), page_size: int = Query(10, ge=1)):
request_id = base.get_task_id(request)
tasks, total = sm.state.get_all_tasks(page, page_size)
response = {
"tasks": tasks,
"total": total,
"page": page,
"page_size": page_size,
}
return utils.get_response(200, response)
@router.get(
"/tasks/{task_id}", response_model=TaskQueryResponse, summary="Query task status"

View File

@ -15,12 +15,23 @@ class BaseState(ABC):
def get_task(self, task_id: str):
pass
@abstractmethod
def get_all_tasks(self, page: int, page_size: int):
pass
# Memory state management
class MemoryState(BaseState):
def __init__(self):
self._tasks = {}
def get_all_tasks(self, page: int, page_size: int):
start = (page - 1) * page_size
end = start + page_size
tasks = list(self._tasks.values())
total = len(tasks)
return tasks[start:end], total
def update_task(
self,
task_id: str,
@ -33,6 +44,7 @@ class MemoryState(BaseState):
progress = 100
self._tasks[task_id] = {
"task_id": task_id,
"state": state,
"progress": progress,
**kwargs,
@ -53,6 +65,28 @@ class RedisState(BaseState):
self._redis = redis.StrictRedis(host=host, port=port, db=db, password=password)
def get_all_tasks(self, page: int, page_size: int):
start = (page - 1) * page_size
end = start + page_size
tasks = []
cursor = 0
total = 0
while True:
cursor, keys = self._redis.scan(cursor, count=page_size)
total += len(keys)
if total > start:
for key in keys[max(0, start - total):end - total]:
task_data = self._redis.hgetall(key)
task = {
k.decode("utf-8"): self._convert_to_original_type(v) for k, v in task_data.items()
}
tasks.append(task)
if len(tasks) >= page_size:
break
if cursor == 0 or len(tasks) >= page_size:
break
return tasks, total
def update_task(
self,
task_id: str,
@ -65,6 +99,7 @@ class RedisState(BaseState):
progress = 100
fields = {
"task_id": task_id,
"state": state,
"progress": progress,
**kwargs,