Skip to content

Commit

Permalink
Merge pull request #9 from bigladder/return-normalization-factor
Browse files Browse the repository at this point in the history
Add return types for normalization factor.
  • Loading branch information
nealkruis authored Dec 16, 2019
2 parents a95d109 + 4634057 commit 692bcce
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 12 deletions.
8 changes: 4 additions & 4 deletions src/btwxt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ std::vector<double> RegularGridInterpolator::get_values_at_target() {
return grid_point.get_results();
}

void RegularGridInterpolator::normalize_values_at_target(std::size_t table_index, const std::vector<double> &target, const double scalar) {
double RegularGridInterpolator::normalize_values_at_target(std::size_t table_index, const std::vector<double> &target, const double scalar) {
set_new_target(target);
normalize_values_at_target(table_index, scalar);
return normalize_values_at_target(table_index, scalar);
}

void RegularGridInterpolator::normalize_values_at_target(std::size_t table_index, const double scalar) {
grid_point.normalize_grid_values_at_target(table_index, scalar);
double RegularGridInterpolator::normalize_values_at_target(std::size_t table_index, const double scalar) {
return grid_point.normalize_grid_values_at_target(table_index, scalar);
}

void RegularGridInterpolator::normalize_values_at_target(const std::vector<double> &target, const double scalar) {
Expand Down
4 changes: 2 additions & 2 deletions src/btwxt.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,9 @@ class RegularGridInterpolator {

void normalize_values_at_target(const std::vector<double> &target, const double scalar = 1.0);

void normalize_values_at_target(std::size_t table_index, const double scalar = 1.0);
double normalize_values_at_target(std::size_t table_index, const double scalar = 1.0);

void normalize_values_at_target(std::size_t table_index, const std::vector<double> &target, const double scalar = 1.0);
double normalize_values_at_target(std::size_t table_index, const std::vector<double> &target, const double scalar = 1.0);

std::vector<double> get_current_target();

Expand Down
10 changes: 7 additions & 3 deletions src/gridpoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,15 +294,19 @@ void GridPoint::normalize_grid_values_at_target(const double scalar) {
set_results();
}

void GridPoint::normalize_grid_values_at_target(std::size_t table_num, const double scalar) {
double GridPoint::normalize_grid_values_at_target(std::size_t table_num, const double scalar) {
if (!target_is_set) {
showMessage(MsgLevel::MSG_WARN,
stringify("Cannot normalize grid values. No target has been set."));
return;
return scalar;
}
grid_data->normalize_value_table(table_num,results[table_num]*scalar);
// create a scalar which represents the product of the inverted normalization factor and the value in the table at the independent variable reference value
double total_scalar = results[table_num]*scalar;
grid_data->normalize_value_table(table_num,total_scalar);
hypercube_cache.clear();
set_results();

return total_scalar;
}

} // namespace Btwxt
4 changes: 2 additions & 2 deletions src/gridpoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ class GridPoint {

double get_vertex_weight(const std::vector<short> &v);

void normalize_grid_values_at_target(std::size_t table_num, const double scalar = 1.0);

void normalize_grid_values_at_target(const double scalar = 1.0);

double normalize_grid_values_at_target(std::size_t table_num, const double scalar = 1.0);

void set_floor();

private:
Expand Down
25 changes: 25 additions & 0 deletions test/btwxt_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,28 @@ TEST_F(TwoDFixture, normalize) {
EXPECT_THAT(result, testing::ElementsAre(testing::DoubleEq(1.0), testing::DoubleEq(8.832)));
Btwxt::LOG_LEVEL = 1;
}

TEST_F(TwoDSimpleNormalizationFixture, normalization_return_scalar) {
std::vector<double> target {7.0, 3.0};
std::vector<double> normalization_target = {2.0, 3.0};
double expected_divisor {test_function(normalization_target)};
double expected_value_at_target {test_function(target)/expected_divisor};
double return_scalar = test_rgi.normalize_values_at_target(0, normalization_target, 1.0);
test_rgi.set_new_target(target);
std::vector<double> results = test_rgi.get_values_at_target();
EXPECT_THAT(return_scalar, testing::DoubleEq(expected_divisor));
EXPECT_THAT(results, testing::ElementsAre(expected_value_at_target));
}

TEST_F(TwoDSimpleNormalizationFixture, normalization_return_compound_scalar) {
std::vector<double> target {7.0, 3.0};
std::vector<double> normalization_target = {2.0, 3.0};
double normalization_divisor = 4.0;
double expected_compound_divisor {test_function(normalization_target)*normalization_divisor};
double expected_value_at_target {test_function(target)/expected_compound_divisor};
double return_scalar = test_rgi.normalize_values_at_target(0, normalization_target, normalization_divisor);
test_rgi.set_new_target(target);
std::vector<double> results = test_rgi.get_values_at_target();
EXPECT_THAT(return_scalar, testing::DoubleEq(expected_compound_divisor));
EXPECT_THAT(results, testing::ElementsAre(expected_value_at_target));
}
28 changes: 27 additions & 1 deletion test/fixtures.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class TwoDFixture : public testing::Test {

TwoDFixture() {
std::vector<std::vector<double>> grid = {{0, 10, 15}, {4, 6}};
// 4 6
// 4 6
values = {{6, 3, // 0
2, 8, // 10
4, 2}, // 15
Expand All @@ -106,6 +106,32 @@ class TwoDFixture : public testing::Test {
}
};

class TwoDSimpleNormalizationFixture : public testing::Test {
// TODO: Create a fixture which this one can inherit from
// takes a vector of functions as a parameter (these become separate value tables)
// takes a vector of vectors which is the data structure that stores the grid
protected:
RegularGridInterpolator test_rgi;
GriddedData test_gridded_data;
double test_function (std::vector<double> target){
assert(target.size() == 2);
return target[0]*target[1];
}

TwoDSimpleNormalizationFixture() {
std::vector<std::vector<double>> grid = {{2.0, 7.0}, {1.0, 2.0, 3.0}};
std::vector<double> values;
for (auto x : grid[0]){
for (auto y : grid[1] ){
values.push_back(test_function({x,y}));
}
}
test_gridded_data = GriddedData(grid, {values});
test_gridded_data.set_axis_extrap_method(0, Method::LINEAR);
test_rgi = RegularGridInterpolator(test_gridded_data);
}
};

class CubicFixture : public testing::Test {
protected:
RegularGridInterpolator test_rgi;
Expand Down

0 comments on commit 692bcce

Please sign in to comment.