Skip to content

Commit

Permalink
Add pin task button, reference issue ArtVentureX#235
Browse files Browse the repository at this point in the history
- add new icon for pinned tasks
- add api do_pin_and_request, pin, unpin
- add pin and run button to history tab
- add pin button to pending tab
- allow modifying of pending tasks when they are pinned
  • Loading branch information
Nyxeka committed Apr 2, 2024
1 parent 721a36f commit 4c22369
Show file tree
Hide file tree
Showing 14 changed files with 277 additions and 115 deletions.
36 changes: 36 additions & 0 deletions agent_scheduler/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,24 @@ def run_task(id: str):

return {"success": True, "message": "Task is executing"}

@app.post("/agent-scheduler/v1/task/{id}/pin", dependencies=deps)
def pin_task(id: str):
task = task_manager.get_task(id)
if task is None:
return {"success": False, "message": "Task not found"}
task.pinned = True
task_manager.update_task(task)
return {"success": True, "message": "Task pinned successfully"}

@app.post("/agent-scheduler/v1/task/{id}/unpin", dependencies=deps)
def unpin_task(id: str):
task = task_manager.get_task(id)
if task is None:
return {"success": False, "message": "Task not found"}
task.pinned = False
task_manager.update_task(task)
return {"success": True, "message": "Task unpinned successfully"}

@app.post("/agent-scheduler/v1/requeue/{id}", dependencies=deps, deprecated=True)
@app.post("/agent-scheduler/v1/task/{id}/requeue", dependencies=deps)
def requeue_task(id: str):
Expand All @@ -347,11 +365,29 @@ def requeue_task(id: str):
task.result = None
task.status = TaskStatus.PENDING
task.bookmarked = False
task.pinned = False
task.name = f"Copy of {task.name}" if task.name else None
task_manager.add_task(task)
task_runner.execute_pending_tasks_threading()

return {"success": True, "message": "Task requeued"}
# /agent-scheduler/v1/task/${id}/requeue-and-pin
@app.post("/agent-scheduler/v1/task/{id}/do-pin-and-requeue", dependencies=deps)
def requeue_and_pin_task(id: str):
task = task_manager.get_task(id)
if task is None:
return {"success": False, "message": "Task not found"}

task.id = str(uuid4())
task.result = None
task.status = TaskStatus.PENDING
task.bookmarked = False
task.pinned = True
task.name = f"Copy of {task.name}" if task.name else None
task_manager.add_task(task)
task_runner.execute_pending_tasks_threading()

return {"success": True, "message": "Task requeued and pinned"}

@app.post("/agent-scheduler/v1/task/requeue-failed", dependencies=deps)
def requeue_failed_tasks():
Expand Down
4 changes: 4 additions & 0 deletions agent_scheduler/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def init():
if not any(col["name"] == "bookmarked" for col in task_columns):
conn.execute(text("ALTER TABLE task ADD COLUMN bookmarked BOOLEAN DEFAULT FALSE"))

# add pinned column
if not any(col["name"] == "pinned" for col in task_columns):
conn.execute(text("ALTER TABLE task ADD COLUMN pinned BOOLEAN DEFAULT FALSE"))

params_column = next(col for col in task_columns if col["name"] == "params")
if version > "1" and not isinstance(params_column["type"], Text):
transaction = conn.begin()
Expand Down
10 changes: 10 additions & 0 deletions agent_scheduler/db/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def from_table(table: "TaskTable"):
status=table.status,
result=table.result,
bookmarked=table.bookmarked,
pinned=table.pinned,
created_at=table.created_at,
updated_at=table.updated_at,
)
Expand All @@ -89,6 +90,7 @@ def to_table(self):
status=self.status,
result=self.result,
bookmarked=self.bookmarked,
pinned=self.pinned,
)

def from_json(json_obj: Dict):
Expand All @@ -104,6 +106,7 @@ def from_json(json_obj: Dict):
priority=json_obj.get("priority", int(datetime.now(timezone.utc).timestamp() * 1000)),
result=json_obj.get("result", None),
bookmarked=json_obj.get("bookmarked", False),
pinned=json_obj.get("pinned", False),
created_at=datetime.fromtimestamp(json_obj.get("created_at", datetime.now(timezone.utc).timestamp())),
updated_at=datetime.fromtimestamp(json_obj.get("updated_at", datetime.now(timezone.utc).timestamp())),
)
Expand All @@ -121,6 +124,7 @@ def to_json(self):
"priority": self.priority,
"result": self.result,
"bookmarked": self.bookmarked,
"pinned": self.pinned,
"created_at": int(self.created_at.timestamp()),
"updated_at": int(self.updated_at.timestamp()),
}
Expand All @@ -140,6 +144,7 @@ class TaskTable(Base):
status = Column(String(20), nullable=False, default="pending") # pending, running, done, failed
result = Column(Text) # task result
bookmarked = Column(Boolean, nullable=True, default=False)
pinned = Column(Boolean, nullable=True, default=False)
created_at = Column(
DateTime,
nullable=False,
Expand Down Expand Up @@ -197,6 +202,7 @@ def get_tasks(
limit: int = None,
offset: int = None,
order: str = "asc",
pinned: bool = None,
) -> List[TaskTable]:
session = Session(self.engine)
try:
Expand All @@ -213,6 +219,10 @@ def get_tasks(
if api_task_id:
query = query.filter(TaskTable.api_task_id == api_task_id)


if pinned is not None:
query = query.filter(TaskTable.pinned == pinned)

if bookmarked == True:
query = query.filter(TaskTable.bookmarked == bookmarked)
else:
Expand Down
1 change: 1 addition & 0 deletions agent_scheduler/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class TaskModel(BaseModel):
position: Optional[int] = Field(title="Task Position")
result: Optional[str] = Field(title="Task Result", description="The result of the task in JSON format")
bookmarked: Optional[bool] = Field(title="Is task bookmarked")
pinned: Optional[bool] = Field(title="Is task pinned")
created_at: Optional[datetime] = Field(
title="Task Created At",
description="The time when the task was created",
Expand Down
31 changes: 28 additions & 3 deletions agent_scheduler/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(self, UiControlNetUnit=None):
# Mark this to True when reload UI
self.dispose = False
self.interrupted = None
self.current_task_pin_status_changed = None

if TaskRunner.instance is not None:
raise Exception("TaskRunner instance already exists")
Expand Down Expand Up @@ -342,6 +343,7 @@ def execute_task(self, task: Task, get_next_task: Callable[[], Task]):
}

self.interrupted = None
self.current_task_pin_status_changed = None
self.__saved_images_path = []
self.__run_callbacks("task_started", task_id, **task_meta)

Expand Down Expand Up @@ -377,6 +379,7 @@ def execute_task(self, task: Task, get_next_task: Callable[[], Task]):
if is_interrupted:
log.info(f"\n[AgentScheduler] Task {task.id} interrupted")
task.status = TaskStatus.INTERRUPTED
task.pinned = False
task_manager.update_task(task)
self.__run_callbacks(
"task_finished",
Expand All @@ -390,10 +393,23 @@ def execute_task(self, task: Task, get_next_task: Callable[[], Task]):
"images": self.__saved_images_path.copy(),
"geninfo": geninfo,
}

task.status = TaskStatus.DONE

try:
# necessary to update in case it changed while pending
task.pinned = task_manager.get_task(task.id).pinned
except Exception as e:
task.pinned = False
log.error(f"[AgentScheduler] Error updating task {task.id} pinned status: {e}")

task.result = json.dumps(result)

if task.pinned:
task.priority = int(datetime.now(timezone.utc).timestamp() * 1000)
cbstatus = TaskStatus.DONE if not task.pinned else TaskStatus.PENDING
task.status = cbstatus

task_manager.update_task(task)

self.__run_callbacks(
"task_finished",
task_id,
Expand Down Expand Up @@ -520,8 +536,17 @@ def __get_pending_task(self):

# get more task if needed
if self.__total_pending_tasks > 0:

# to-do: implement task priority button or something?

# first search non-pinned:
log.info(f"[AgentScheduler] Total pending tasks: {self.__total_pending_tasks}")
pending_tasks = task_manager.get_tasks(status="pending", limit=1)
pending_tasks = task_manager.get_tasks(status="pending", limit=1, pinned=False)
if len(pending_tasks) > 0:
return pending_tasks[0]

# finally, look for pinned:
pending_tasks = task_manager.get_tasks(status="pending", limit=1, pinned=True)
if len(pending_tasks) > 0:
return pending_tasks[0]
else:
Expand Down
Loading

0 comments on commit 4c22369

Please sign in to comment.