diff --git a/docs/examples/plot_lcppn_explainer.py b/docs/examples/plot_lcppn_explainer.py
index b4864541..c9d3dad0 100644
--- a/docs/examples/plot_lcppn_explainer.py
+++ b/docs/examples/plot_lcppn_explainer.py
@@ -10,25 +10,11 @@
 """
 from sklearn.ensemble import RandomForestClassifier
 from hiclass import LocalClassifierPerParentNode, Explainer
-import requests
-import pandas as pd
 import shap
+from hiclass.datasets import load_platypus
 
-# Download training data
-url = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/1152c0b9613c2bda144a38fc4f74b5fe12255f4d/platypus_diseases.csv"
-path = "platypus_diseases.csv"
-response = requests.get(url)
-with open(path, "wb") as file:
-    file.write(response.content)
-
-# Load training data into pandas dataframe
-training_data = pd.read_csv(path).fillna(" ")
-
-# Define data
-X_train = training_data.drop(["label"], axis=1)
-X_test = X_train[:100]  # Use first 100 samples as test set
-Y_train = training_data["label"]
-Y_train = [eval(my) for my in Y_train]
+# Load train and test splits
+X_train, X_test, Y_train, Y_test = load_platypus()
 
 # Use random forest classifiers for every node
 rfc = RandomForestClassifier()
diff --git a/docs/examples/plot_parallel_training.py b/docs/examples/plot_parallel_training.py
index bc18e7e1..db875794 100644
--- a/docs/examples/plot_parallel_training.py
+++ b/docs/examples/plot_parallel_training.py
@@ -17,25 +17,15 @@
 """
 import sys
 from os import cpu_count
-
-import pandas as pd
-import requests
 from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer
 from sklearn.linear_model import LogisticRegression
 from sklearn.pipeline import Pipeline
 
 from hiclass import LocalClassifierPerParentNode
+from hiclass.datasets import load_hierarchical_text_classification
 
-
-# Download training data
-url = "https://zenodo.org/record/6657410/files/train_40k.csv?download=1"
-path = "train_40k.csv"
-response = requests.get(url)
-with open(path, "wb") as file:
-    file.write(response.content)
-
-# Load training data into pandas dataframe
-training_data = pd.read_csv(path).fillna(" ")
+# Load train and test splits
+X_train, X_test, Y_train, Y_test = load_hierarchical_text_classification()
 
 # We will use logistic regression classifiers for every parent node
 lr = LogisticRegression(max_iter=1000)
@@ -51,10 +41,6 @@
     ]
 )
 
-# Select training data
-X_train = training_data["Title"]
-Y_train = training_data[["Cat1", "Cat2", "Cat3"]]
-
 # Fixes bug AttributeError: '_LoggingTee' object has no attribute 'fileno'
 # This only happens when building the documentation
 # Hence, you don't actually need it for your code to work
diff --git a/docs/source/api/utilities.rst b/docs/source/api/utilities.rst
index 2ba43849..faf790f9 100644
--- a/docs/source/api/utilities.rst
+++ b/docs/source/api/utilities.rst
@@ -88,3 +88,23 @@ F-score
 ^^^^^^^
 
 .. autofunction:: metrics.f1
+
+..................................
+
+
+Datasets
+----------
+
+Platypus diseases dataset
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autofunction:: datasets.load_platypus
+
+..................................
+
+Hierarchical text classification dataset
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+.. autofunction:: datasets.load_hierarchical_text_classification
+
+..................................
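All three documentation-side changes above funnel through the new `hiclass.datasets` loaders. As a quick smoke test of the pattern the updated examples now follow, a minimal sketch (assuming this branch is installed; the `RandomForestClassifier` choice mirrors the explainer example and is illustrative, not part of the diff):

```python
from sklearn.ensemble import RandomForestClassifier

from hiclass import LocalClassifierPerParentNode
from hiclass.datasets import load_platypus

# The loader downloads the CSV once, caches it in the system temp
# directory, and returns a train_test_split tuple.
X_train, X_test, Y_train, Y_test = load_platypus()

# Fit one hierarchical classifier per parent node on the cached split
classifier = LocalClassifierPerParentNode(local_classifier=RandomForestClassifier())
classifier.fit(X_train, Y_train)
predictions = classifier.predict(X_test)
print(predictions[:3])
```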
diff --git a/hiclass/__init__.py b/hiclass/__init__.py
index ec3db00e..884ba4cf 100644
--- a/hiclass/__init__.py
+++ b/hiclass/__init__.py
@@ -22,4 +22,5 @@
     "Explainer",
     "MultiLabelLocalClassifierPerNode",
     "MultiLabelLocalClassifierPerParentNode",
+    "datasets",
 ]
diff --git a/hiclass/datasets.py b/hiclass/datasets.py
new file mode 100644
index 00000000..55819a8f
--- /dev/null
+++ b/hiclass/datasets.py
@@ -0,0 +1,140 @@
+"""Utilities for downloading and caching sample datasets."""
+
+import requests
+import pandas as pd
+import os
+import tempfile
+import logging
+from sklearn.model_selection import train_test_split
+
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Use temp directory to store cached datasets
+CACHE_DIR = tempfile.gettempdir()
+
+# Ensure cache directory exists
+os.makedirs(CACHE_DIR, exist_ok=True)
+
+# Dataset URLs
+PLATYPUS_URL = "https://gist.githubusercontent.com/ashishpatel16/9306f8ed3ed101e7ddcb519776bcbd80/raw/1152c0b9613c2bda144a38fc4f74b5fe12255f4d/platypus_diseases.csv"
+HIERARCHICAL_TEXT_CLASSIFICATION_URL = (
+    "https://zenodo.org/record/6657410/files/train_40k.csv?download=1"
+)
+
+
+def _download_file(url, destination):
+    """Download a file from the given URL to the specified destination."""
+    try:
+        response = requests.get(url)
+        # Raise HTTPError if response code is not OK
+        response.raise_for_status()
+        with open(destination, "wb") as f:
+            f.write(response.content)
+    except requests.RequestException as e:
+        raise RuntimeError(f"Failed to download file from {url}: {str(e)}")
+
+
+def load_platypus(test_size=0.3, random_state=42):
+    """
+    Load the platypus diseases dataset.
+
+    Parameters
+    ----------
+    test_size : float, default=0.3
+        The proportion of the dataset to include in the test split.
+    random_state : int or None, default=42
+        Controls the shuffling applied to the data before the split. Pass an int for reproducible output across multiple function calls.
+
+    Returns
+    -------
+    list
+        List containing the train-test split of inputs.
+
+    Raises
+    ------
+    RuntimeError
+        If the dataset could not be accessed or processed.
+
+    Examples
+    --------
+    >>> from hiclass.datasets import load_platypus
+    >>> X_train, X_test, Y_train, Y_test = load_platypus()
+    >>> X_train[:3]
+        fever  diarrhea  stomach pain  skin rash  cough  sniffles  short breath  headache  size
+    220  37.8         0             3          5      1         1             0         2  27.6
+    539  37.2         0             6          1      1         1             0         3  28.4
+    326  39.9         0             2          5      1         1             1         2  30.7
+    >>> X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
+    ((572, 9), (246, 9), (572,), (246,))
+    """
+    dataset_name = "platypus_diseases.csv"
+    cached_file_path = os.path.join(CACHE_DIR, dataset_name)
+
+    # Check if the file exists in the cache
+    if not os.path.exists(cached_file_path):
+        try:
+            logger.info("Downloading platypus diseases dataset...")
+            _download_file(PLATYPUS_URL, cached_file_path)
+        except Exception as e:
+            raise RuntimeError(f"Failed to access or download dataset: {str(e)}")
+
+    data = pd.read_csv(cached_file_path).fillna(" ")
+    X = data.drop(["label"], axis=1)
+    y = pd.Series([eval(val) for val in data["label"]])
+
+    # Return tuple (X_train, X_test, y_train, y_test)
+    return train_test_split(X, y, test_size=test_size, random_state=random_state)
+
+
+def load_hierarchical_text_classification(test_size=0.3, random_state=42):
+    """
+    Load the hierarchical text classification dataset.
+
+    Parameters
+    ----------
+    test_size : float, default=0.3
+        The proportion of the dataset to include in the test split.
+    random_state : int or None, default=42
+        Controls the shuffling applied to the data before the split.
+        Pass an int for reproducible output across multiple function calls.
+
+    Returns
+    -------
+    list
+        List containing the train-test split of inputs.
+
+    Raises
+    ------
+    RuntimeError
+        If the dataset could not be accessed or processed.
+
+    Examples
+    --------
+    >>> from hiclass.datasets import load_hierarchical_text_classification
+    >>> X_train, X_test, Y_train, Y_test = load_hierarchical_text_classification()
+    >>> X_train[:3]
+    38015                                Nature's Way Selenium
+    2281         Music In Motion Developmental Mobile W Remote
+    36629    Twinings Ceylon Orange Pekoe Tea, Tea Bags, 20...
+    Name: Title, dtype: object
+    >>> X_train.shape, X_test.shape, Y_train.shape, Y_test.shape
+    ((28000,), (12000,), (28000, 3), (12000, 3))
+    """
+    dataset_name = "hierarchical_text_classification.csv"
+    cached_file_path = os.path.join(CACHE_DIR, dataset_name)
+
+    # Check if the file exists in the cache
+    if not os.path.exists(cached_file_path):
+        try:
+            logger.info("Downloading hierarchical text classification dataset...")
+            _download_file(HIERARCHICAL_TEXT_CLASSIFICATION_URL, cached_file_path)
+        except Exception as e:
+            raise RuntimeError(f"Failed to access or download dataset: {str(e)}")
+
+    data = pd.read_csv(cached_file_path).fillna(" ")
+    X = data["Title"]
+    y = data[["Cat1", "Cat2", "Cat3"]]
+
+    # Return tuple (X_train, X_test, y_train, y_test)
+    return train_test_split(X, y, test_size=test_size, random_state=random_state)
diff --git a/tests/test_Datasets.py b/tests/test_Datasets.py
new file mode 100644
index 00000000..1d47ba5e
--- /dev/null
+++ b/tests/test_Datasets.py
@@ -0,0 +1,120 @@
+import pytest
+
+import hiclass.datasets
+from hiclass.datasets import load_platypus, load_hierarchical_text_classification
+import os
+import tempfile
+
+
+def test_load_platypus_output_shape():
+    X_train, X_test, y_train, y_test = load_platypus(test_size=0.2, random_state=42)
+    assert X_train.shape[0] == y_train.shape[0]
+    assert X_test.shape[0] == y_test.shape[0]
+
+
+def test_load_platypus_random_state():
+    X_train_1, X_test_1, y_train_1, y_test_1 = load_platypus(
+        test_size=0.2, random_state=42
+    )
+    X_train_2, X_test_2, y_train_2, y_test_2 = load_platypus(
+        test_size=0.2, random_state=42
+    )
+    assert (X_train_1.values == X_train_2.values).all()
+    assert (X_test_1.values == X_test_2.values).all()
+    assert (y_train_1.index == y_train_2.index).all()
+    assert (y_test_1.index == y_test_2.index).all()
+
+
+def test_load_hierarchical_text_classification_shape():
+    X_train, X_test, y_train, y_test = load_hierarchical_text_classification(
+        test_size=0.2, random_state=42
+    )
+    assert X_train.shape[0] == y_train.shape[0]
+    assert X_test.shape[0] == y_test.shape[0]
+
+
+def test_load_hierarchical_text_classification_random_state():
+    X_train_1, X_test_1, y_train_1, y_test_1 = load_hierarchical_text_classification(
+        test_size=0.2, random_state=42
+    )
+    X_train_2, X_test_2, y_train_2, y_test_2 = load_hierarchical_text_classification(
+        test_size=0.2, random_state=42
+    )
+    assert (X_train_1 == X_train_2).all()
+    assert (X_test_1 == X_test_2).all()
+    assert (y_train_1.index == y_train_2.index).all()
+    assert (y_test_1.index == y_test_2.index).all()
+
+
+def test_load_hierarchical_text_classification_file_exists():
+    dataset_name = "hierarchical_text_classification.csv"
+    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)
+
+    if os.path.exists(cached_file_path):
+        os.remove(cached_file_path)
+
+    if not os.path.exists(cached_file_path):
+        load_hierarchical_text_classification()
+        assert os.path.exists(cached_file_path)
+
+
+def test_load_platypus_file_exists():
+    dataset_name = "platypus_diseases.csv"
+    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)
+
+    if os.path.exists(cached_file_path):
+        os.remove(cached_file_path)
+
+    if not os.path.exists(cached_file_path):
+        load_platypus()
+        assert os.path.exists(cached_file_path)
+
+
+def test_download_dataset():
+    dataset_name = "platypus_diseases_test.csv"
+    url = hiclass.datasets.PLATYPUS_URL
+    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)
+
+    if os.path.exists(cached_file_path):
+        os.remove(cached_file_path)
+
+    if not os.path.exists(cached_file_path):
+        hiclass.datasets._download_file(url, cached_file_path)
+        assert os.path.exists(cached_file_path)
+
+
+def test_download_error_load_platypus():
+    dataset_name = "platypus_diseases.csv"
+    backup_url = hiclass.datasets.PLATYPUS_URL
+    hiclass.datasets.PLATYPUS_URL = ""
+    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)
+
+    if os.path.exists(cached_file_path):
+        os.remove(cached_file_path)
+
+    if not os.path.exists(cached_file_path):
+        with pytest.raises(RuntimeError):
+            load_platypus()
+
+    hiclass.datasets.PLATYPUS_URL = backup_url
+
+
+def test_download_error_load_hierarchical_text():
+    dataset_name = "hierarchical_text_classification.csv"
+    backup_url = hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL
+    hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL = ""
+    cached_file_path = os.path.join(tempfile.gettempdir(), dataset_name)
+
+    if os.path.exists(cached_file_path):
+        os.remove(cached_file_path)
+
+    if not os.path.exists(cached_file_path):
+        with pytest.raises(RuntimeError):
+            load_hierarchical_text_classification()
+
+    hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL = backup_url
+
+
+def test_url_links():
+    assert hiclass.datasets.PLATYPUS_URL != ""
+    assert hiclass.datasets.HIERARCHICAL_TEXT_CLASSIFICATION_URL != ""
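One note on the two download-error tests: they blank out the module-level URL and restore it by hand, so a failure inside the `pytest.raises` block would leak the empty URL into later tests. A possible alternative, sketched here rather than proposed as part of the diff, is pytest's built-in `monkeypatch` fixture, which undoes the attribute change automatically on teardown:

```python
import os
import tempfile

import pytest

import hiclass.datasets
from hiclass.datasets import load_platypus


def test_download_error_load_platypus(monkeypatch):
    # monkeypatch restores PLATYPUS_URL after the test, even if it fails
    monkeypatch.setattr(hiclass.datasets, "PLATYPUS_URL", "")
    cached_file_path = os.path.join(tempfile.gettempdir(), "platypus_diseases.csv")

    # Remove any cached copy so load_platypus is forced to download
    if os.path.exists(cached_file_path):
        os.remove(cached_file_path)

    with pytest.raises(RuntimeError):
        load_platypus()
```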