From 688a0f806653056c6736d68dd2e9141747a89d7a Mon Sep 17 00:00:00 2001 From: Santiago Figueroa Manrique Date: Mon, 30 Sep 2024 17:31:07 +0200 Subject: [PATCH] improved batch validation for columnar data Signed-off-by: Santiago Figueroa Manrique --- src/power_grid_model/validation/validation.py | 15 +++++---------- tests/unit/validation/test_batch_validation.py | 2 +- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/power_grid_model/validation/validation.py b/src/power_grid_model/validation/validation.py index e904f63fd..98bb05ee3 100644 --- a/src/power_grid_model/validation/validation.py +++ b/src/power_grid_model/validation/validation.py @@ -135,22 +135,17 @@ def validate_batch_data( input_errors: list[ValidationError] = list(validate_unique_ids_across_components(input_data)) - # Convert to row based if in columnar format - # TODO(figueroa1395): transform to columnar per single batch scenario once the columnar dataset python extension - # is finished - row_update_data = compatibility_convert_row_columnar_dataset(update_data, None, DatasetType.update) - - # Splitting update_data_into_batches may raise TypeErrors and ValueErrors - batch_data = convert_batch_dataset_to_batch_list(row_update_data) + batch_data = convert_batch_dataset_to_batch_list(update_data) errors = {} for batch, batch_update_data in enumerate(batch_data): - assert_valid_data_structure(batch_update_data, DatasetType.update) - id_errors: list[ValidationError] = list(validate_ids_exist(batch_update_data, input_data)) + row_update_data = compatibility_convert_row_columnar_dataset(batch_update_data, None, DatasetType.update) + assert_valid_data_structure(row_update_data, DatasetType.update) + id_errors: list[ValidationError] = list(validate_ids_exist(row_update_data, input_data)) batch_errors = input_errors + id_errors if not id_errors: - merged_data = update_input_data(input_data, batch_update_data) + merged_data = update_input_data(input_data, row_update_data) batch_errors += validate_required_values(merged_data, calculation_type, symmetric) batch_errors += validate_values(merged_data, calculation_type) diff --git a/tests/unit/validation/test_batch_validation.py b/tests/unit/validation/test_batch_validation.py index ca9e55a1d..4fac6d858 100644 --- a/tests/unit/validation/test_batch_validation.py +++ b/tests/unit/validation/test_batch_validation.py @@ -80,7 +80,7 @@ def test_validate_batch_data_input_error(input_data, batch_data): def test_validate_batch_data_update_error(input_data, batch_data): - batch_data["line"]["from_status"] = [[12, 34], [0, -128], [56, 78]] + batch_data["line"]["from_status"] = np.array([[12, 34], [0, -128], [56, 78]]) errors = validate_batch_data(input_data, batch_data) assert len(errors) == 2 assert [NotBooleanError("line", "from_status", [5, 6])] == errors[0]