Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to upload artifacts #26

Merged
merged 2 commits into from
Aug 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 36 additions & 7 deletions agent/python/agent_protocol/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
import os

from fastapi import APIRouter
import aiofiles
from fastapi import APIRouter, UploadFile
from fastapi.responses import FileResponse
from hypercorn.asyncio import serve
from hypercorn.config import Config
Expand Down Expand Up @@ -100,10 +101,6 @@ async def execute_agent_task_step(
step = await _step_handler(step)

step.status = Status.completed

if step.artifacts:
task.artifacts.extend(step.artifacts)

return step


Expand Down Expand Up @@ -132,6 +129,31 @@ async def list_agent_task_artifacts(task_id: str) -> List[Artifact]:
return task.artifacts


@base_router.post(
"/agent/tasks/{task_id}/artifacts",
response_model=Artifact,
tags=["agent"],
)
async def upload_agent_task_artifacts(
task_id: str, file: UploadFile, relative_path: Optional[str] = None
) -> Artifact:
"""
Upload an artifact for the specified task.
"""
await Agent.db.get_task(task_id)
artifact = await Agent.db.create_artifact(task_id, file.filename, relative_path)

path = Agent.get_artifact_folder(task_id, artifact)
if not os.path.exists(path):
os.makedirs(path)

async with aiofiles.open(os.path.join(path, file.filename), "wb") as f:
while content := await file.read(1024 * 1024): # async read chunk ~1MiB
await f.write(content)

return artifact


@base_router.get(
"/agent/tasks/{task_id}/artifacts/{artifact_id}",
tags=["agent"],
Expand Down Expand Up @@ -172,13 +194,20 @@ def get_workspace(task_id: str) -> str:
return os.path.join(os.getcwd(), Agent.workspace, task_id)

@staticmethod
def get_artifact_path(task_id: str, artifact: Artifact) -> str:
def get_artifact_folder(task_id: str, artifact: Artifact) -> str:
"""
Get the artifact path for the specified task and artifact.
"""
workspace_path = Agent.get_workspace(task_id)
relative_path = artifact.relative_path or ""
return os.path.join(workspace_path, relative_path, artifact.file_name)
return os.path.join(workspace_path, relative_path)

@staticmethod
def get_artifact_path(task_id: str, artifact: Artifact) -> str:
"""
Get the artifact path for the specified task and artifact.
"""
return os.path.join(Agent.get_artifact_folder(task_id, artifact), artifact.file_name)

@staticmethod
def start(port: int = 8000, router: APIRouter = base_router):
Expand Down
29 changes: 29 additions & 0 deletions agent/python/agent_protocol/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,15 @@ async def create_step(
) -> Step:
raise NotImplementedError

async def create_artifact(
self,
task_id: str,
file_name: str,
relative_path: Optional[str] = None,
step_id: Optional[str] = None,
) -> Artifact:
raise NotImplementedError

async def get_task(self, task_id: str) -> Task:
raise NotImplementedError

Expand Down Expand Up @@ -114,6 +123,26 @@ async def get_artifact(self, task_id: str, artifact_id: str) -> Artifact:
raise Exception(f"Artifact with id {artifact_id} not found")
return artifact

async def create_artifact(
self,
task_id: str,
file_name: str,
relative_path: Optional[str] = None,
step_id: Optional[str] = None,
) -> Artifact:
artifact_id = str(uuid.uuid4())
artifact = Artifact(
artifact_id=artifact_id, file_name=file_name, relative_path=relative_path
)
task = await self.get_task(task_id)
task.artifacts.append(artifact)

if step_id:
step = await self.get_step(task_id, step_id)
step.artifacts.append(artifact)

return artifact

async def list_tasks(self) -> List[Task]:
return [task for task in self._tasks.values()]

Expand Down
7 changes: 7 additions & 0 deletions agent/python/agent_protocol/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ class Artifact(BaseModel):
)


class ArtifactUpload(BaseModel):
file: bytes = Field(..., description="File to upload.")
relative_path: Optional[str] = Field(
None, description="Relative path of the artifact in the agent's workspace."
)


class StepInput(BaseModel):
__root__: Any = Field(
..., description="Input parameters for the task step. Any value is allowed."
Expand Down
Empty file.
11 changes: 6 additions & 5 deletions agent/python/examples/smol_developer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,17 @@ async def _generate_code(task: Task, step: Step) -> Step:
file_path = step.additional_properties["file_path"]

code = await generate_code(task.input, shared_deps, file_path)
step.output = code

write_file(os.path.join(Agent.get_workspace(task.task_id), file_path), code)
path = Path("./" + file_path)
artifact = Artifact(
artifact_id=str(uuid.uuid4()),
file_name=path.name,
await Agent.db.create_artifact(
task_id=task.task_id,
step_id=step.step_id,
relative_path=str(path.parent),
file_name=path.name,
)

step.output = code
step.artifacts.append(artifact)
return step


Expand Down
27 changes: 26 additions & 1 deletion agent/python/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion agent/python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "agent-protocol"
version = "0.2.2"
version = "0.2.3"
description = "API for interacting with Agent"
authors = ["e2b <hello@e2b.dev>"]
license = "MIT"
Expand All @@ -17,6 +17,8 @@ pytest = "^7.0.0"
pydantic = "^1.10.5, <2"
click = "^8.1.6"
requests = "^2.31.0"
python-multipart = "^0.0.6"
aiofiles = "^23.1.0"

[tool.poetry.group.dev.dependencies]
fastapi-code-generator = "^0.4.2"
Expand Down
Loading