Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MS-731][FE] - upload dataset endpoint #394

Merged
merged 17 commits into [base branch] from [head branch]
Jan 24, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions moonshot/integrations/web_api/routes/dataset.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
from dependency_injector.wiring import Provide, inject
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile

from ..container import Container
from ..schemas.dataset_create_dto import CSV_Dataset_DTO, HF_Dataset_DTO
from ..schemas.dataset_response_dto import DatasetResponseDTO
from ..services.dataset_service import DatasetService
from ..services.utils.exceptions_handler import ServiceException
import tempfile
import os

router = APIRouter(tags=["Datasets"])


@router.post("/api/v1/datasets/csv")
@inject
def convert_dataset(
dataset_data: CSV_Dataset_DTO,
async def upload_dataset(
file: UploadFile = File(...),
name: str = Form(..., min_length=1),
description: str = Form(default="", min_length=1),
license: str = Form(default=""),
reference: str = Form(default=""),
dataset_service: DatasetService = Depends(Provide[Container.dataset_service]),
) -> str:
"""
Expand All @@ -32,6 +38,22 @@ def convert_dataset(
An error with status code 400 if there is a validation error.
An error with status code 500 for any other server-side error.
"""

# Create a temporary file with a secure random name
with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[1]) as tmp_file:
content = await file.read()
tmp_file.write(content)
temp_file_path = tmp_file.name

# Create the DTO with the form data including optional fields
dataset_data = CSV_Dataset_DTO(
name=name,
description=description,
license=license,
reference=reference,
csv_file_path=temp_file_path
)

try:
return dataset_service.convert_dataset(dataset_data)
except ServiceException as e:
Expand All @@ -47,6 +69,10 @@ def convert_dataset(
raise HTTPException(
status_code=500, detail=f"Failed to convert dataset: {e.msg}"
)
finally:
# Clean up the temporary file
if os.path.exists(temp_file_path):
os.unlink(temp_file_path)


@router.post("/api/v1/datasets/hf")
Expand Down
3 changes: 2 additions & 1 deletion moonshot/integrations/web_api/services/dataset_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from ..services.base_service import BaseService
from ..services.utils.exceptions_handler import exception_handler
from .utils.file_manager import copy_file
import os


class DatasetService(BaseService):
Expand All @@ -29,7 +30,7 @@ def convert_dataset(self, dataset_data: CSV_Dataset_DTO) -> str:
license=dataset_data.license,
csv_file_path=dataset_data.csv_file_path,
)
return copy_file(new_ds_path)
return os.path.abspath(new_ds_path)

@exception_handler
def download_dataset(self, dataset_data: HF_Dataset_DTO) -> str:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ dependencies = [
"datasets>=2.21.0",
"pandas>=2.2.2",
"numpy>=1.26.4",
"tenacity>=8.5.0"
"tenacity>=8.5.0",
"python-multipart>=0.0.9",
]

[project.optional-dependencies]
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pyreadline3~=3.4.1 ; python_version >= "3.11" and python_version < "3.12" and sy
python-dateutil~=2.9.0.post0 ; python_version >= "3.11" and python_version < "3.12"
python-dotenv~=1.0.1 ; python_version >= "3.11" and python_version < "3.12"
python-slugify~=8.0.4 ; python_version >= "3.11" and python_version < "3.12"
python-multipart~=0.0.9 ; python_version >= "3.11" and python_version < "3.12"
pytz~=2024.2 ; python_version >= "3.11" and python_version < "3.12"
pyyaml~=6.0.2 ; python_version >= "3.11" and python_version < "3.12"
requests~=2.32.3 ; python_version >= "3.11" and python_version < "3.12"
Expand Down
Loading