diff --git a/moonshot/integrations/web_api/routes/dataset.py b/moonshot/integrations/web_api/routes/dataset.py
index e85ffb45..cd5fb646 100644
--- a/moonshot/integrations/web_api/routes/dataset.py
+++ b/moonshot/integrations/web_api/routes/dataset.py
@@ -1,5 +1,8 @@
+import os
+import tempfile
+
 from dependency_injector.wiring import Provide, inject
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile
 
 from ..container import Container
 from ..schemas.dataset_create_dto import CSV_Dataset_DTO, HF_Dataset_DTO
@@ -10,10 +13,14 @@
 router = APIRouter(tags=["Datasets"])
 
 
-@router.post("/api/v1/datasets/csv")
+@router.post("/api/v1/datasets/file")
 @inject
-def convert_dataset(
-    dataset_data: CSV_Dataset_DTO,
+async def upload_dataset(
+    file: UploadFile = File(...),
+    name: str = Form(..., min_length=1),
+    description: str = Form(default="", min_length=1),
+    license: str = Form(default=""),
+    reference: str = Form(default=""),
     dataset_service: DatasetService = Depends(Provide[Container.dataset_service]),
 ) -> str:
     """
@@ -32,7 +39,24 @@
         An error with status code 400 if there is a validation error.
         An error with status code 500 for any other server-side error.
     """
+
+    # Create a temporary file with a secure random name
+    with tempfile.NamedTemporaryFile(
+        delete=False, suffix=os.path.splitext(file.filename or "")[1]
+    ) as tmp_file:
+        content = await file.read()
+        tmp_file.write(content)
+        temp_file_path = tmp_file.name
+
     try:
+        # Create the DTO with the form data, including optional fields
+        dataset_data = CSV_Dataset_DTO(
+            name=name,
+            description=description,
+            license=license,
+            reference=reference,
+            file_path=temp_file_path,
+        )
         return dataset_service.convert_dataset(dataset_data)
     except ServiceException as e:
         if e.error_code == "FileNotFound":
@@ -47,6 +71,10 @@
         raise HTTPException(
             status_code=500, detail=f"Failed to convert dataset: {e.msg}"
         )
+    finally:
+        # Clean up the temporary file
+        if os.path.exists(temp_file_path):
+            os.unlink(temp_file_path)
 
 
 @router.post("/api/v1/datasets/hf")
diff --git a/moonshot/integrations/web_api/schemas/dataset_create_dto.py b/moonshot/integrations/web_api/schemas/dataset_create_dto.py
index 0ca6bc68..d4f0e100 100644
--- a/moonshot/integrations/web_api/schemas/dataset_create_dto.py
+++ b/moonshot/integrations/web_api/schemas/dataset_create_dto.py
@@ -8,13 +8,13 @@
 
 
 class CSV_Dataset_DTO(DatasetPydanticModel):
-    id: Optional[str] = None  # Not a required from user
-    examples: Optional[Any] = None  # Not a required from user
+    id: Optional[str] = None  # Not required from user
+    examples: Optional[Any] = None  # Not required from user
     name: str = Field(..., min_length=1)
     description: str = Field(default="", min_length=1)
     license: Optional[str] = ""
     reference: Optional[str] = ""
-    csv_file_path: str = Field(..., min_length=1)
+    file_path: str = Field(..., min_length=1)
 
 
 class HF_Dataset_DTO(DatasetPydanticModel):
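The route now takes a multipart upload (the file plus form fields) instead of a JSON body pointing at a server-side path. A minimal client call might look like the sketch below; the host, port, file name, and metadata values are assumptions, not part of this change.

```python
import requests

# Upload a local CSV to the new endpoint: metadata travels as form
# fields, the dataset itself as a multipart file part.
with open("my_dataset.csv", "rb") as f:
    response = requests.post(
        "http://localhost:5000/api/v1/datasets/file",  # assumed host/port
        data={
            "name": "My Dataset",
            "description": "Uploaded as multipart form data",
            "license": "MIT",
            "reference": "https://example.com",
        },
        files={"file": ("my_dataset.csv", f, "text/csv")},
    )

# On success: 200 and the ID of the newly created dataset
print(response.status_code, response.json())
```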
diff --git a/moonshot/integrations/web_api/services/dataset_service.py b/moonshot/integrations/web_api/services/dataset_service.py
index d3024bfb..37944c9f 100644
--- a/moonshot/integrations/web_api/services/dataset_service.py
+++ b/moonshot/integrations/web_api/services/dataset_service.py
@@ -4,6 +4,7 @@
 from ..services.base_service import BaseService
 from ..services.utils.exceptions_handler import exception_handler
 from .utils.file_manager import copy_file
+import os
 
 
 class DatasetService(BaseService):
@@ -16,7 +17,7 @@ def convert_dataset(self, dataset_data: CSV_Dataset_DTO) -> str:
             dataset_data (CSV_Dataset_DTO): The data required to convert the dataset.
 
         Returns:
-            str: The path to the newly created dataset.
+            str: The ID of the newly created dataset (its filename without the extension).
 
         Raises:
             Exception: If an error occurs during dataset conversion.
@@ -27,9 +28,9 @@ def convert_dataset(self, dataset_data: CSV_Dataset_DTO) -> str:
             description=dataset_data.description,
             reference=dataset_data.reference,
             license=dataset_data.license,
-            csv_file_path=dataset_data.csv_file_path,
+            file_path=dataset_data.file_path,
         )
-        return copy_file(new_ds_path)
+        return os.path.splitext(os.path.basename(new_ds_path))[0]
 
     @exception_handler
     def download_dataset(self, dataset_data: HF_Dataset_DTO) -> str:
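The service no longer copies the created file and returns its path; it derives the dataset ID from the path reported by api_convert_dataset. A small illustration of that derivation, using a hypothetical path:

```python
import os

# Hypothetical path returned by moonshot_api.api_convert_dataset
new_ds_path = "moonshot-data/datasets/my-dataset.json"

# basename -> "my-dataset.json"; splitext -> ("my-dataset", ".json")
dataset_id = os.path.splitext(os.path.basename(new_ds_path))[0]
assert dataset_id == "my-dataset"
```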
diff --git a/moonshot/src/api/api_dataset.py b/moonshot/src/api/api_dataset.py
index ed6bf098..141180e2 100644
--- a/moonshot/src/api/api_dataset.py
+++ b/moonshot/src/api/api_dataset.py
@@ -1,3 +1,6 @@
+import json
+import os
+
 from pydantic import validate_call
 
 from moonshot.src.datasets.dataset import Dataset
@@ -81,10 +84,10 @@ def api_download_dataset(
 
 
 def api_convert_dataset(
-    name: str, description: str, reference: str, license: str, csv_file_path: str
+    name: str, description: str, reference: str, license: str, file_path: str
 ) -> str:
     """
-    Converts a CSV file to a dataset and creates a new dataset with the provided details.
+    Converts a CSV or JSON file to a dataset and creates a new dataset with the provided details.
 
     This function takes the name, description, reference, and license for a new dataset as input, along with the
     file path to a CSV file. It then creates a new DatasetArguments object with these details and an empty id. The id is left
@@ -96,18 +99,56 @@
         description (str): A brief description of the new dataset.
         reference (str): A reference link for the new dataset.
         license (str): The license of the new dataset.
-        csv_file_path (str): The file path to the CSV file.
+        file_path (str): The file path to the CSV or JSON file.
 
     Returns:
         str: The ID of the newly created dataset.
     """
-    examples = Dataset.convert_data(csv_file_path)
-    ds_args = DatasetArguments(
-        id="",
-        name=name,
-        description=description,
-        reference=reference,
-        license=license,
-        examples=examples,
-    )
+    ds_args = None
+
+    # Check that the file is in a supported format
+    if not (file_path.endswith(".json") or file_path.endswith(".csv")):
+        raise ValueError("Unsupported file format. Please provide a JSON or CSV file.")
+
+    # Check that the file is not empty
+    if os.path.getsize(file_path) == 0:
+        raise ValueError("The uploaded file is empty.")
+
+    # If the file is already in JSON format
+    if file_path.endswith(".json"):
+        with open(file_path) as json_file:
+            json_data = json.load(json_file)
+
+        try:
+            if "examples" in json_data and json_data["examples"]:
+                ds_args = DatasetArguments(
+                    id="",
+                    name=json_data.get("name", name),
+                    description=json_data.get("description", description),
+                    reference=json_data.get("reference", reference),
+                    license=json_data.get("license", license),
+                    examples=iter(json_data["examples"]),
+                )
+            else:
+                raise KeyError(
+                    "'examples' is either empty or missing from the JSON file. "
+                    "Please ensure that this field is present."
+                )
+        except Exception as e:
+            raise e
+
+    # If the file is in CSV format, convert the data
+    else:
+        try:
+            examples = Dataset.convert_data(file_path)
+            ds_args = DatasetArguments(
+                id="",
+                name=name,
+                description=description,
+                reference=reference,
+                license=license,
+                examples=examples,
+            )
+        except Exception as e:
+            raise e
 
     return Dataset.create(ds_args)
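For reference, a JSON upload that passes the new checks might look like the file sketched below. The metadata values are invented, and the "input"/"target" shape of each example is inferred from the CSV header validation in dataset.py; the JSON branch itself only requires a non-empty "examples" list.

```python
import json

# Keys other than "examples", when present, override the
# caller-supplied metadata (see the json_data.get(...) calls above).
sample = {
    "name": "sample-dataset",
    "description": "A tiny illustrative dataset",
    "license": "Apache-2.0",
    "reference": "https://example.com",
    "examples": [
        {"input": "What is 1 + 1?", "target": "2"},
        {"input": "What is 2 + 2?", "target": "4"},
    ],
}

with open("sample-dataset.json", "w") as f:
    json.dump(sample, f, indent=2)
```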
diff --git a/moonshot/src/datasets/dataset.py b/moonshot/src/datasets/dataset.py
index 5703376a..2e077cea 100644
--- a/moonshot/src/datasets/dataset.py
+++ b/moonshot/src/datasets/dataset.py
@@ -60,7 +60,6 @@ def create(ds_args: DatasetArguments) -> str:
         }
 
         examples = ds_args.examples
-
         # Write as JSON output
         file_path = Storage.create_object_with_iterator(
             EnvVariables.DATASETS.name,
@@ -91,9 +90,26 @@ def convert_data(csv_file_path: str) -> Iterator[dict]:
         Returns:
             Iterator[dict]: An iterator of dictionaries representing the CSV data.
         """
+        # Validate the headers
+        df_header = pd.read_csv(csv_file_path, nrows=1)
+        headers = df_header.columns.tolist()
+        required_headers = ["input", "target"]
+        if not all(header in headers for header in required_headers):
+            raise KeyError(
+                f"Required headers not found in the dataset. Required headers are {required_headers}."
+            )
+
         df = pd.read_csv(csv_file_path, chunksize=1)
-        for chunk in df:
-            yield chunk.to_dict("records")[0]
+        # Validate that the dataset contains at least one row
+        first_chunk = next(df, None)
+        if first_chunk is None or first_chunk.empty:
+            raise ValueError("The uploaded file does not contain any data.")
+
+        # Re-read, since next(df) has consumed the first chunk
+        df = pd.read_csv(csv_file_path, chunksize=1)
+
+        result = [chunk.to_dict("records")[0] for chunk in df]
+        return iter(result)
 
     @staticmethod
     @validate_call
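Likewise, a CSV accepted by the new validation in convert_data needs at least the "input" and "target" headers plus one data row; extra columns pass the all(...) membership check untouched. A sketch of producing such a file (the file name is illustrative):

```python
import csv

rows = [
    {"input": "What is 1 + 1?", "target": "2"},
    {"input": "What is 2 + 2?", "target": "4"},
]

# Write a header row plus data rows; convert_data will later yield
# each row back as a dict, e.g. {"input": "...", "target": "2"}.
with open("sample-dataset.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["input", "target"])
    writer.writeheader()
    writer.writerows(rows)
```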
diff --git a/pyproject.toml b/pyproject.toml
index 48c7fdba..5504d8da 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,7 +29,8 @@
     "datasets>=2.21.0",
    "pandas>=2.2.2",
    "numpy>=1.26.4",
-    "tenacity>=8.5.0"
+    "tenacity>=8.5.0",
+    "python-multipart>=0.0.9",
 ]
 
 [project.optional-dependencies]
diff --git a/requirements.txt b/requirements.txt
index 389c0dce..a1a47a98 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -40,6 +40,7 @@ pyreadline3~=3.4.1 ; python_version >= "3.11" and python_version < "3.12" and sy
 python-dateutil~=2.9.0.post0 ; python_version >= "3.11" and python_version < "3.12"
 python-dotenv~=1.0.1 ; python_version >= "3.11" and python_version < "3.12"
 python-slugify~=8.0.4 ; python_version >= "3.11" and python_version < "3.12"
+python-multipart~=0.0.9 ; python_version >= "3.11" and python_version < "3.12"
 pytz~=2024.2 ; python_version >= "3.11" and python_version < "3.12"
 pyyaml~=6.0.2 ; python_version >= "3.11" and python_version < "3.12"
 requests~=2.32.3 ; python_version >= "3.11" and python_version < "3.12"
diff --git a/tests/unit-tests/src/test_api_dataset.py b/tests/unit-tests/src/test_api_dataset.py
index 6c35f432..77f25426 100644
--- a/tests/unit-tests/src/test_api_dataset.py
+++ b/tests/unit-tests/src/test_api_dataset.py
@@ -212,7 +212,7 @@ def test_api_delete_dataset(self, dataset_id, expected_dict):
             "reference": "www.reference.com",
             "license": "LICENSE",
             "method": "csv",
-            "csv_file_path": "tests/unit-tests/common/samples/sample-dataset.csv"
+            "file_path": "tests/unit-tests/common/samples/sample-dataset.csv"
         },
         "tests/unit-tests/src/data/datasets/test-csv-dataset.json"
     )
@@ -235,10 +235,10 @@ def test_api_convert_dataset(self, dataset_details, expected_result):
         description = dataset_details.pop('description')
         reference = dataset_details.pop('reference')
         license = dataset_details.pop('license')
-        csv_file_path= dataset_details.pop('csv_file_path')
+        file_path = dataset_details.pop('file_path')
 
         # Call the api_convert_dataset function with unpacked arguments
-        result = api_convert_dataset(name, description, reference, license, csv_file_path)
+        result = api_convert_dataset(name, description, reference, license, file_path)
 
         # Assert that the result matches the expected result
         assert result == expected_result, f"The result '{result}' does not match the expected result '{expected_result}'."
diff --git a/tests/unit-tests/web_api/test_routes/test_routes_dataset.py b/tests/unit-tests/web_api/test_routes/test_routes_dataset.py
index 769509df..96483227 100644
--- a/tests/unit-tests/web_api/test_routes/test_routes_dataset.py
+++ b/tests/unit-tests/web_api/test_routes/test_routes_dataset.py
@@ -169,68 +169,78 @@ def test_download_dataset(test_client, mock_dataset_service, dataset_data, exception, expected_status, expected_response, mocker):
     assert response.json() == expected_response
 
 
-@pytest.mark.parametrize("dataset_data, exception, expected_status, expected_response", [
-    # Successful case for "csv" method
+@pytest.mark.parametrize("file_name, form_data, exception, expected_status, expected_response", [
+    # Successful case
     (
+        "test.csv",
         {
             "name": "New Dataset",
             "description": "This dataset is created from postman",
             "license": "norman license",
             "reference": "reference.com",
-            "csv_file_path": "/Users/normanchia/LocalDocs/your_dataset.csv",
         },
         None,
         200,
         "Dataset created successfully"
     ),
-    # Exception case for validation error
+    # Exception cases
     (
+        "test.csv",
         {
             "name": "New Dataset",
             "description": "This dataset is created from postman",
             "license": "norman license",
             "reference": "reference.com",
-            "csv_file_path": "/Users/normanchia/LocalDocs/your_dataset.csv",
         },
         ServiceException("A validation error occurred", "create_dataset", "ValidationError"),
         400,
         {'detail': 'Failed to convert dataset: [ServiceException] ValidationError in create_dataset - A validation error occurred'}
     ),
-    # Exception case for file not found error
     (
+        "test.csv",
         {
             "name": "New Dataset",
             "description": "This dataset is created from postman",
             "license": "norman license",
             "reference": "reference.com",
-            "csv_file_path": "/Users/normanchia/LocalDocs/your_dataset.csv",
         },
         ServiceException("A file not found error occurred", "create_dataset", "FileNotFound"),
         404,
         {'detail': 'Failed to convert dataset: [ServiceException] FileNotFound in create_dataset - A file not found error occurred'}
     ),
-    # Exception case for server error
     (
+        "test.csv",
         {
             "name": "New Dataset",
             "description": "This dataset is created from postman",
             "license": "norman license",
             "reference": "reference.com",
-            "csv_file_path": "/Users/normanchia/LocalDocs/your_dataset.csv",
         },
         ServiceException("An server error occurred", "create_dataset", "ServerError"),
         500,
         {'detail': 'Failed to convert dataset: [ServiceException] ServerError in create_dataset - An server error occurred'}
     ),
 ])
-def test_convert_dataset(test_client, mock_dataset_service, dataset_data, exception, expected_status, expected_response, mocker):
+def test_convert_dataset(test_client, mock_dataset_service, file_name, form_data, exception, expected_status, expected_response, mocker):
     mocker.patch("moonshot.integrations.web_api.routes.dataset.Provide", return_value=mock_dataset_service)
 
+    # Create mock file content
+    file_content = b"mock,csv,content\n1,2,3"
+    files = {
+        "file": (file_name, file_content, "text/csv")
+    }
+
     if exception:
         mock_dataset_service.convert_dataset.side_effect = exception
     else:
         mock_dataset_service.convert_dataset.return_value = expected_response
 
-    response = test_client.post("/api/v1/datasets/csv", json=dataset_data)
+    # Use the test client with form data and files
+    response = test_client.post(
+        "/api/v1/datasets/file",
+        data=form_data,
+        files=files
+    )
+
     assert response.status_code == expected_status
     assert response.json() == expected_response
\ No newline at end of file
diff --git a/tests/unit-tests/web_api/test_services/test_service_dataset.py b/tests/unit-tests/web_api/test_services/test_service_dataset.py
index 612a1c7b..4f11e5ac 100644
--- a/tests/unit-tests/web_api/test_services/test_service_dataset.py
+++ b/tests/unit-tests/web_api/test_services/test_service_dataset.py
@@ -43,7 +43,7 @@
     description="This dataset is created from postman",
     reference="reference.com",
     license="license",
-    csv_file_path="tests/unit-tests/common/samples/sample-dataset.csv"
+    file_path="tests/unit-tests/common/samples/sample-dataset.csv"
 )
 
 # Exception scenarios to test
@@ -159,12 +159,11 @@
 def test_convert_dataset_success(mock_copy_file, mock_moonshot_api):
     dataset_service = DatasetService()
     result = dataset_service.convert_dataset(MOCK_DATASET_CREATE_DTO_CSV)
-    assert result == "Dataset created successfully"
+    assert result == "dataset"
     mock_moonshot_api.api_convert_dataset.assert_called_once_with(
         name="New Dataset",
         description="This dataset is created from postman",
         reference="reference.com",
         license="license",
-        csv_file_path="tests/unit-tests/common/samples/sample-dataset.csv"
-    )
-    mock_copy_file.assert_called_once_with("/path/to/new/dataset")
\ No newline at end of file
+        file_path="tests/unit-tests/common/samples/sample-dataset.csv"
+    )
\ No newline at end of file