Skip to content

Commit

Permalink
Global rename of data_module to module_name
Browse files Browse the repository at this point in the history
  • Loading branch information
grovduck committed Jul 10, 2024
1 parent e79709b commit e73846f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 11 deletions.
18 changes: 9 additions & 9 deletions src/sknnr/datasets/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,15 +68,15 @@ def _open_text(module_name: str | types.ModuleType, file_name: str) -> IO[str]:


def load_csv_data(
file_name: str, *, data_module: str | types.ModuleType = DATA_MODULE
file_name: str, *, module_name: str | types.ModuleType = DATA_MODULE
) -> tuple[NDArray[np.int64], NDArray[np.float64], NDArray[np.str_]]:
"""Load data from a CSV file from the specified data_module.
"""Load data from a CSV file from the specified module_name.
Parameters
----------
file_name: str, required
The filename of the CSV file to load from `data_module/file_name`.
data_module: str or module, default='sknnr.datasets.data'
The filename of the CSV file to load from `module_name/file_name`.
module_name: str or module, default='sknnr.datasets.data'
The module where the data file is located.
Returns
Expand All @@ -93,7 +93,7 @@ def load_csv_data(
The CSV must be formatted with plot IDs in the first column and data values in the
remaining columns. The first row must contain the column names.
"""
with _open_text(data_module, file_name) as csv_file:
with _open_text(module_name, file_name) as csv_file:
data_file = csv.reader(csv_file)
headers = next(data_file)
rows = list(iter(data_file))
Expand All @@ -111,7 +111,7 @@ def load_dataset_from_csv_filenames(
target_filename: str,
return_X_y: bool = False,
as_frame: bool = False,
data_module: str | types.ModuleType = DATA_MODULE,
module_name: str | types.ModuleType = DATA_MODULE,
) -> tuple[NDArray[np.float64], NDArray[np.float64]] | Dataset:
"""Load separate data and target CSV files into a dataset or paired NumPy arrays.
Expand All @@ -128,7 +128,7 @@ def load_dataset_from_csv_filenames(
DataFrames instead of NumPy arrays. The `frame` attribute will also be added as
a DataFrame with the dataset index. Pandas must be installed for this
option.
data_module: str or module, default='sknnr.datasets.data'
module_name: str or module, default='sknnr.datasets.data'
The module where the data files are located.
Returns
Expand All @@ -144,10 +144,10 @@ def load_dataset_from_csv_filenames(
The plot IDs in each file are expected to match and be in the same order.
"""
index, data, feature_names = load_csv_data(
file_name=data_filename, data_module=data_module
file_name=data_filename, module_name=module_name
)
_, target, target_names = load_csv_data(
file_name=target_filename, data_module=data_module
file_name=target_filename, module_name=module_name
)

dataset = Dataset(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ def test_passing_module_type():
"""Test that passing a module type to load_csv_data works."""
import sknnr.datasets.data as data_module

load_csv_data("moscow_env.csv", data_module=data_module)
load_csv_data("moscow_env.csv", module_name=data_module)


def test_incorrect_module_raises_on_load_csv():
"""Test that load_csv_data raises when given an invalid module name."""
invalid_module = "sknnr.datasets.invalid_module"
with pytest.raises(ModuleNotFoundError, match="No module named"):
load_csv_data("moscow_env.csv", data_module=invalid_module)
load_csv_data("moscow_env.csv", module_name=invalid_module)


@pytest.mark.parametrize(
Expand Down

0 comments on commit e73846f

Please sign in to comment.