Skip to content

Commit

Permalink
Merge pull request #732 from PowerGridModel/fix/columnar-convert-muta…
Browse files Browse the repository at this point in the history
…ble-to-const-1

MutableDataset to ConstDataset for columnar data
  • Loading branch information
TonyXiang8787 authored Sep 24, 2024
2 parents 2843f29 + 24f1f55 commit 8d014bf
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,14 @@ template <dataset_type_tag dataset_type_> class Dataset {
: meta_data_{&other.meta_data()}, dataset_info_{other.get_description()} {
for (Idx i{}; i != other.n_components(); ++i) {
auto const& buffer = other.get_buffer(i);
buffers_.push_back(Buffer{.data = buffer.data, .indptr = buffer.indptr});
Buffer new_buffer{.data = buffer.data, .indptr = buffer.indptr};
for (auto const& attribute_buffer : buffer.attributes) {

AttributeBuffer<Data> const new_attribute_buffer{.data = attribute_buffer.data,
.meta_attribute = attribute_buffer.meta_attribute};
new_buffer.attributes.emplace_back(new_attribute_buffer);
}
buffers_.push_back(new_buffer);
}
}

Expand Down
37 changes: 22 additions & 15 deletions tests/cpp_unit_tests/test_dataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,9 @@ TEST_CASE_TEMPLATE("Test dataset (common)", DatasetType, ConstDataset, MutableDa

add_inhomogeneous_buffer(dataset, A::name, total_elements, a_indptr.data(), nullptr);

auto const check_span = [&](auto const& buffer_span, Idx const& scenario = -1) {
auto const check_span = [&total_elements, &elements_per_scenarios, &a_indptr, &id_buffer,
&a1_buffer]<typename T>(auto const& buffer_span,
Idx const& scenario = -1) {
auto element_number = total_elements;
Idx aux_idx = 0;
if (scenario != -1) {
Expand All @@ -880,45 +882,50 @@ TEST_CASE_TEMPLATE("Test dataset (common)", DatasetType, ConstDataset, MutableDa
}
CHECK(buffer_span.size() == element_number);
for (Idx idx = 0; idx < buffer_span.size(); ++idx) {
auto const element = get_colummnar_element<DatasetType>(buffer_span, idx);
auto const element = get_colummnar_element<T>(buffer_span, idx);
CHECK(element.id == id_buffer[aux_idx + idx]);
CHECK(element.a1 == a1_buffer[aux_idx + idx]);
CHECK(is_nan(element.a0));
}
};
auto const check_all_spans = [&](auto const& scenario) {
check_span(dataset.template get_columnar_buffer_span<input_getter_s, A>());
check_span(
dataset.template get_columnar_buffer_span<input_getter_s, A>(DatasetType::invalid_index));
auto const check_all_spans = [&check_span, &batch_size]<typename T>(auto& any_dataset,
auto const& scenario) {
check_span.template operator()<T>(
any_dataset.template get_columnar_buffer_span<input_getter_s, A>());
check_span.template operator()<T>(
any_dataset.template get_columnar_buffer_span<input_getter_s, A>(T::invalid_index));

auto const all_scenario_spans =
dataset.template get_columnar_buffer_span_all_scenarios<input_getter_s, A>();
any_dataset.template get_columnar_buffer_span_all_scenarios<input_getter_s, A>();
CHECK(all_scenario_spans.size() == batch_size);

auto const scenario_span =
dataset.template get_columnar_buffer_span<input_getter_s, A>(scenario);
check_span(scenario_span, scenario);
any_dataset.template get_columnar_buffer_span<input_getter_s, A>(scenario);
check_span.template operator()<T>(scenario_span, scenario);
CHECK(all_scenario_spans[scenario].size() == scenario_span.size());
check_span(all_scenario_spans[scenario], scenario);
check_span.template operator()<T>(all_scenario_spans[scenario], scenario);
};
add_attribute_buffer(dataset, A::name, A::InputType::a1_name, a1_buffer.data());
add_attribute_buffer(dataset, A::name, A::InputType::id_name, id_buffer.data());
for (Idx scenario : {0, 1, 2, 3}) {
CAPTURE(scenario);
if (scenario < batch_size) {
check_all_spans(scenario);
check_all_spans.template operator()<DatasetType>(dataset, scenario);

std::ranges::fill(id_buffer, 1);
check_all_spans(scenario);
check_all_spans.template operator()<DatasetType>(dataset, scenario);

std::transform(boost::counting_iterator<ID>{0},
boost::counting_iterator<ID>{static_cast<ID>(total_elements)},
id_buffer.begin(), [](ID value) { return value * 2; });
check_all_spans(scenario);
check_all_spans.template operator()<DatasetType>(dataset, scenario);

std::ranges::transform(id_buffer, a1_buffer.begin(),
[](ID value) { return static_cast<double>(value); });
check_all_spans(scenario);
check_all_spans.template operator()<DatasetType>(dataset, scenario);

auto dataset_copy = ConstDataset{dataset};
check_all_spans.template operator()<ConstDataset>(dataset_copy, scenario);

if constexpr (!std::same_as<DatasetType, ConstDataset>) {
auto buffer_span =
Expand All @@ -928,7 +935,7 @@ TEST_CASE_TEMPLATE("Test dataset (common)", DatasetType, ConstDataset, MutableDa
buffer_span[idx] = A::InputType{.id = -10, .a0 = -1.0, .a1 = -2.0};
CHECK(id_buffer[idx + (a_indptr[scenario])] == -10);
CHECK(a1_buffer[idx + (a_indptr[scenario])] == -2.0);
check_all_spans(scenario);
check_all_spans.template operator()<DatasetType>(dataset, scenario);
}
}
}
Expand Down

0 comments on commit 8d014bf

Please sign in to comment.